@@ -68,6 +68,21 @@ nvinfer1::ILayer* add_elementwise(
68
68
return ele;
69
69
}
70
70
71
+ nvinfer1::ITensor* clamp_util (
72
+ ConversionCtx* ctx,
73
+ const torch::jit::Node* n,
74
+ nvinfer1::ITensor* self,
75
+ float limit,
76
+ nvinfer1::ElementWiseOperation op_type,
77
+ std::string str) {
78
+ nvinfer1::ITensor* clamp_layer_out = self;
79
+ auto limitTensor = tensor_to_const (ctx, torch::tensor ({limit}));
80
+ auto limit_layer = add_elementwise (ctx, op_type, clamp_layer_out, limitTensor, util::node_info (n) + str);
81
+ TRTORCH_CHECK (limit_layer, " Unable to create elementwise " << str << " layer for node: " << *n);
82
+ clamp_layer_out = limit_layer->getOutput (0 );
83
+ return clamp_layer_out;
84
+ }
85
+
71
86
auto element_wise_registrations TRTORCH_UNUSED =
72
87
RegisterNodeConversionPatterns ()
73
88
.pattern({" aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
@@ -145,38 +160,58 @@ auto element_wise_registrations TRTORCH_UNUSED =
145
160
return true ;
146
161
}})
147
162
.pattern({" aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)" ,
163
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
164
+ // Compute min(max(min_threshold, input), max_threshold)
165
+ auto self = args[0 ].ITensorOrFreeze (ctx);
166
+ auto clamp_layer_out = self;
167
+
168
+ if (args[1 ].isIValue () && args[1 ].IValue ()->isScalar () && args[2 ].isIValue () &&
169
+ args[2 ].IValue ()->isScalar ()) {
170
+ auto alpha = args[1 ].unwrapToScalar ().to <float >();
171
+ auto beta = args[2 ].unwrapToScalar ().to <float >();
172
+ auto clip_layer = ctx->net ->addActivation (*self, nvinfer1::ActivationType::kCLIP );
173
+ TRTORCH_CHECK (clip_layer, " Unable to create clip layer for node: " << *n);
174
+ clip_layer->setAlpha (alpha);
175
+ clip_layer->setBeta (beta);
176
+ clamp_layer_out = clip_layer->getOutput (0 );
177
+ } else if (args[1 ].isIValue () && args[1 ].IValue ()->isScalar ()) {
178
+ auto limit = args[1 ].unwrapToScalar ().to <float >();
179
+ clamp_layer_out = clamp_util (ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMAX , " _max" );
180
+ } else if (args[2 ].isIValue () && args[2 ].IValue ()->isScalar ()) {
181
+ auto limit = args[2 ].unwrapToScalar ().to <float >();
182
+ clamp_layer_out = clamp_util (ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMIN , " _min" );
183
+ }
184
+
185
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], clamp_layer_out);
186
+ LOG_DEBUG (" Clamp layer output tensor shape: " << out->getDimensions ());
187
+ return true ;
188
+ }})
189
+ .pattern({" aten::clamp_min(Tensor self, Scalar min) -> (Tensor)" ,
148
190
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
149
191
// Compute min(max(min_threshold, input), max_threshold)
150
192
auto self = args[0 ].ITensorOrFreeze (ctx);
151
193
auto clamp_layer_out = self;
152
194
if (args[1 ].isIValue () && args[1 ].IValue ()->isScalar ()) {
153
- auto minScalar = args[1 ].unwrapToScalar ().to <float >();
154
- auto minTensor = tensor_to_const (ctx, torch::tensor ({minScalar}));
155
- auto max_layer = add_elementwise (
156
- ctx,
157
- nvinfer1::ElementWiseOperation::kMAX ,
158
- clamp_layer_out,
159
- minTensor,
160
- util::node_info (n) + std::string (" _max" ));
161
- TRTORCH_CHECK (max_layer, " Unable to create elementwise max layer for node: " << *n);
162
- clamp_layer_out = max_layer->getOutput (0 );
195
+ auto limit = args[1 ].unwrapToScalar ().to <float >();
196
+ clamp_layer_out = clamp_util (ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMAX , " _max" );
163
197
}
164
198
165
- if (args[2 ].isIValue () && args[2 ].IValue ()->isScalar ()) {
166
- auto maxScalar = args[2 ].unwrapToScalar ().to <float >();
167
- auto maxTensor = tensor_to_const (ctx, torch::tensor ({maxScalar}));
168
- auto min_layer = add_elementwise (
169
- ctx,
170
- nvinfer1::ElementWiseOperation::kMIN ,
171
- clamp_layer_out,
172
- maxTensor,
173
- util::node_info (n) + std::string (" _min" ));
174
- TRTORCH_CHECK (min_layer, " Unable to create elementwise min layer for node: " << *n);
175
- clamp_layer_out = min_layer->getOutput (0 );
199
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], clamp_layer_out);
200
+ LOG_DEBUG (" clamp_min layer output tensor shape: " << out->getDimensions ());
201
+ return true ;
202
+ }})
203
+ .pattern({" aten::clamp_max(Tensor self, Scalar max) -> (Tensor)" ,
204
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
205
+ // Compute min(max(min_threshold, input), max_threshold)
206
+ auto self = args[0 ].ITensorOrFreeze (ctx);
207
+ auto clamp_layer_out = self;
208
+ if (args[1 ].isIValue () && args[1 ].IValue ()->isScalar ()) {
209
+ auto limit = args[1 ].unwrapToScalar ().to <float >();
210
+ clamp_layer_out = clamp_util (ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMIN , " _min" );
176
211
}
177
212
178
213
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], clamp_layer_out);
179
- LOG_DEBUG (" Clamp layer output tensor shape: " << out->getDimensions ());
214
+ LOG_DEBUG (" clamp_max layer output tensor shape: " << out->getDimensions ());
180
215
return true ;
181
216
}})
182
217
.pattern({" aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
0 commit comments