Skip to content

Commit 0a278f2

Browse files
authored
Merge pull request #329 from inocsin/clamp_max_min
add clamp_min/clamp_max converter
2 parents 62b077e + 684a318 commit 0a278f2

File tree

2 files changed

+79
-26
lines changed

2 files changed

+79
-26
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ nvinfer1::ILayer* add_elementwise(
6868
return ele;
6969
}
7070

71+
nvinfer1::ITensor* clamp_util(
72+
ConversionCtx* ctx,
73+
const torch::jit::Node* n,
74+
nvinfer1::ITensor* self,
75+
float limit,
76+
nvinfer1::ElementWiseOperation op_type,
77+
std::string str) {
78+
nvinfer1::ITensor* clamp_layer_out = self;
79+
auto limitTensor = tensor_to_const(ctx, torch::tensor({limit}));
80+
auto limit_layer = add_elementwise(ctx, op_type, clamp_layer_out, limitTensor, util::node_info(n) + str);
81+
TRTORCH_CHECK(limit_layer, "Unable to create elementwise " << str << " layer for node: " << *n);
82+
clamp_layer_out = limit_layer->getOutput(0);
83+
return clamp_layer_out;
84+
}
85+
7186
auto element_wise_registrations TRTORCH_UNUSED =
7287
RegisterNodeConversionPatterns()
7388
.pattern({"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
@@ -145,38 +160,58 @@ auto element_wise_registrations TRTORCH_UNUSED =
145160
return true;
146161
}})
147162
.pattern({"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Clamp input to [min, max], i.e. min(max(input, min), max);
            // either bound may be absent (None).
            auto in = args[0].ITensorOrFreeze(ctx);
            auto out_tensor = in;

            const bool has_min = args[1].isIValue() && args[1].IValue()->isScalar();
            const bool has_max = args[2].isIValue() && args[2].IValue()->isScalar();

            if (has_min && has_max) {
              // Both bounds present: a single CLIP activation covers the range.
              const auto lower = args[1].unwrapToScalar().to<float>();
              const auto upper = args[2].unwrapToScalar().to<float>();
              auto clip_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kCLIP);
              TRTORCH_CHECK(clip_layer, "Unable to create clip layer for node: " << *n);
              clip_layer->setAlpha(lower);
              clip_layer->setBeta(upper);
              out_tensor = clip_layer->getOutput(0);
            } else if (has_min) {
              // Lower bound only: elementwise max against the threshold.
              const auto lower = args[1].unwrapToScalar().to<float>();
              out_tensor = clamp_util(ctx, n, in, lower, nvinfer1::ElementWiseOperation::kMAX, "_max");
            } else if (has_max) {
              // Upper bound only: elementwise min against the threshold.
              const auto upper = args[2].unwrapToScalar().to<float>();
              out_tensor = clamp_util(ctx, n, in, upper, nvinfer1::ElementWiseOperation::kMIN, "_min");
            }
            // If neither bound is a scalar, the input passes through unchanged.

            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
            LOG_DEBUG("Clamp layer output tensor shape: " << out->getDimensions());
            return true;
          }})
189+
.pattern({"aten::clamp_min(Tensor self, Scalar min) -> (Tensor)",
148190
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
149191
// Compute min(max(min_threshold, input), max_threshold)
150192
auto self = args[0].ITensorOrFreeze(ctx);
151193
auto clamp_layer_out = self;
152194
if (args[1].isIValue() && args[1].IValue()->isScalar()) {
153-
auto minScalar = args[1].unwrapToScalar().to<float>();
154-
auto minTensor = tensor_to_const(ctx, torch::tensor({minScalar}));
155-
auto max_layer = add_elementwise(
156-
ctx,
157-
nvinfer1::ElementWiseOperation::kMAX,
158-
clamp_layer_out,
159-
minTensor,
160-
util::node_info(n) + std::string("_max"));
161-
TRTORCH_CHECK(max_layer, "Unable to create elementwise max layer for node: " << *n);
162-
clamp_layer_out = max_layer->getOutput(0);
195+
auto limit = args[1].unwrapToScalar().to<float>();
196+
clamp_layer_out = clamp_util(ctx, n, self, limit, nvinfer1::ElementWiseOperation::kMAX, "_max");
163197
}
164198

165-
if (args[2].isIValue() && args[2].IValue()->isScalar()) {
166-
auto maxScalar = args[2].unwrapToScalar().to<float>();
167-
auto maxTensor = tensor_to_const(ctx, torch::tensor({maxScalar}));
168-
auto min_layer = add_elementwise(
169-
ctx,
170-
nvinfer1::ElementWiseOperation::kMIN,
171-
clamp_layer_out,
172-
maxTensor,
173-
util::node_info(n) + std::string("_min"));
174-
TRTORCH_CHECK(min_layer, "Unable to create elementwise min layer for node: " << *n);
175-
clamp_layer_out = min_layer->getOutput(0);
199+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], clamp_layer_out);
200+
LOG_DEBUG("clamp_min layer output tensor shape: " << out->getDimensions());
201+
return true;
202+
}})
203+
.pattern({"aten::clamp_max(Tensor self, Scalar max) -> (Tensor)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            // Upper-bound clamp: output = min(input, max).
            auto in = args[0].ITensorOrFreeze(ctx);
            auto out_tensor = in;
            if (args[1].isIValue() && args[1].IValue()->isScalar()) {
              const auto upper = args[1].unwrapToScalar().to<float>();
              out_tensor = clamp_util(ctx, n, in, upper, nvinfer1::ElementWiseOperation::kMIN, "_min");
            }
            auto out = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
            LOG_DEBUG("clamp_max layer output tensor shape: " << out->getDimensions());
            return true;
          }})
182217
.pattern({"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
250250
TEST(Converters, ATenClampMinConvertsCorrectly) {
251251
const auto graph = R"IR(
252252
graph(%x.1 : Tensor):
253-
%2 : int = prim::Constant[value=-2]()
253+
%2 : float = prim::Constant[value=1.5]()
254254
%3 : None = prim::Constant()
255255
%4 : Tensor = aten::clamp(%x.1, %2, %3)
256256
return (%4))IR";
@@ -260,7 +260,7 @@ TEST(Converters, ATenClampMinConvertsCorrectly) {
260260
TEST(Converters, ATenClampMaxConvertsCorrectly) {
261261
const auto graph = R"IR(
262262
graph(%x.1 : Tensor):
263-
%2 : int = prim::Constant[value=3]()
263+
%2 : float = prim::Constant[value=3.5]()
264264
%3 : None = prim::Constant()
265265
%4 : Tensor = aten::clamp(%x.1, %3, %2)
266266
return (%4))IR";
@@ -270,13 +270,31 @@ TEST(Converters, ATenClampMaxConvertsCorrectly) {
270270
TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
271271
const auto graph = R"IR(
272272
graph(%x.1 : Tensor):
273-
%2 : int = prim::Constant[value=3]()
274-
%3 : int = prim::Constant[value=-2]()
273+
%2 : float = prim::Constant[value=3.5]()
274+
%3 : float = prim::Constant[value=1.5]()
275275
%4 : Tensor = aten::clamp(%x.1, %3, %2)
276276
return (%4))IR";
277277
pointwise_test_helper(graph, true);
278278
}
279279

280+
// Converts a graph containing aten::clamp_min with a scalar lower bound.
// NOTE(review): pointwise_test_helper presumably runs the graph through both
// Torch and TRTorch and compares outputs — confirm against the helper's
// definition; its body is not visible here.
TEST(Converters, ATenClampMinimumConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : float = prim::Constant[value=2.5]()
      %4 : Tensor = aten::clamp_min(%x.1, %2)
      return (%4))IR";
  pointwise_test_helper(graph, true);
}
288+
289+
// Converts a graph containing aten::clamp_max with a scalar upper bound.
// NOTE(review): pointwise_test_helper presumably runs the graph through both
// Torch and TRTorch and compares outputs — confirm against the helper's
// definition; its body is not visible here.
TEST(Converters, ATenClampMaximumConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : float = prim::Constant[value=2.5]()
      %4 : Tensor = aten::clamp_max(%x.1, %2)
      return (%4))IR";
  pointwise_test_helper(graph, true);
}
297+
280298
TEST(Converters, ATenGreaterThanConvertsCorrectly) {
281299
const auto graph = R"IR(
282300
graph(%0 : Tensor, %1 : Tensor):

0 commit comments

Comments
 (0)