Skip to content

Commit d90e0d1

Browse files
samurdhikarukevinch-nv
authored andcommitted
Fix Normalize_TRT plugin segfault
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent 7818985 commit d90e0d1

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

plugin/normalizePlugin/normalizePlugin.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const char* NORMALIZE_PLUGIN_NAME{"Normalize_TRT"};
3535
PluginFieldCollection NormalizePluginCreator::mFC{};
3636
std::vector<PluginField> NormalizePluginCreator::mPluginAttributes;
3737

38-
Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
38+
Normalize::Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps)
3939
: acrossSpatial(acrossSpatial)
4040
, channelShared(channelShared)
4141
, eps(eps)
@@ -44,11 +44,13 @@ Normalize::Normalize(const Weights* weights, int nbWeights, bool acrossSpatial,
4444
PLUGIN_VALIDATE(nbWeights == 1);
4545
PLUGIN_VALIDATE(weights[0].count >= 1);
4646
mWeights = copyToDevice(weights[0].values, weights[0].count);
47+
mScalarScale = static_cast<float const*>(weights[0].values)[0];
4748
}
4849

4950
Normalize::Normalize(
50-
const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
51-
: acrossSpatial(acrossSpatial)
51+
Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W)
52+
: mScalarScale(scalarScale)
53+
, acrossSpatial(acrossSpatial)
5254
, channelShared(channelShared)
5355
, eps(eps)
5456
, C(C)
@@ -74,6 +76,7 @@ Normalize::Normalize(const void* buffer, size_t length)
7476

7577
mNbWeights = read<int>(d);
7678
int count = read<int>(d);
79+
std::memcpy(&mScalarScale, d, sizeof(float));
7780
mWeights = deserializeToDevice(d, count);
7881
PLUGIN_VALIDATE(d == a + length);
7982
}
@@ -111,8 +114,19 @@ int Normalize::enqueue(
111114
{
112115
const void* inputData = inputs[0];
113116
void* outputData = outputs[0];
114-
pluginStatus_t status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
115-
static_cast<const float*>(mWeights.values), inputData, outputData, workspace);
117+
118+
pluginStatus_t status;
119+
120+
if(acrossSpatial && channelShared) // Since cublasPointerMode_t is CUBLAS_POINTER_MODE_HOST, scale should be on the host
121+
{
122+
status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
123+
&mScalarScale, inputData, outputData, workspace);
124+
}
125+
else // No risk of device pointers being passed to cublas as alpha or beta
126+
{
127+
status = normalizeInference(stream, mCublas, acrossSpatial, channelShared, batchSize, C, H, W, eps,
128+
static_cast<float const*>(mWeights.values), inputData, outputData, workspace);
129+
}
116130

117131
return status;
118132
}
@@ -254,7 +268,7 @@ IPluginV2Ext* Normalize::clone() const noexcept
254268
try
255269
{
256270
// Create a new instance
257-
IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, acrossSpatial, channelShared, eps, C, H, W);
271+
IPluginV2Ext* plugin = new Normalize(&mWeights, mNbWeights, mScalarScale, acrossSpatial, channelShared, eps, C, H, W);
258272

259273
// Set the namespace
260274
plugin->setPluginNamespace(mPluginNamespace.c_str());

plugin/normalizePlugin/normalizePlugin.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ namespace plugin
3131
class Normalize : public IPluginV2Ext
3232
{
3333
public:
34-
Normalize(const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps);
34+
Normalize(Weights const* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps);
3535

3636
Normalize(
37-
const Weights* weights, int nbWeights, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W);
37+
Weights const* weights, int nbWeights, float scalarScale, bool acrossSpatial, bool channelShared, float eps, int C, int H, int W);
3838

3939
Normalize(const void* buffer, size_t length);
4040

@@ -93,8 +93,9 @@ class Normalize : public IPluginV2Ext
9393

9494
cublasHandle_t mCublas;
9595

96-
Weights mWeights{};
96+
Weights mWeights{}; // mWeights.values is on the device
9797
int mNbWeights{};
98+
float mScalarScale{}; // keep track of scale on the host (for when channelShared is true)
9899
bool acrossSpatial{};
99100
bool channelShared{};
100101
float eps{};

0 commit comments

Comments
 (0)