Skip to content

Commit d8151d9

Browse files
authored
Merge pull request #249 from NVIDIA/add_alpha
Fix add ops with alpha multiplier
2 parents 91bf074 + b0cb9b4 commit d8151d9

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
7979
auto scalar = args[2].unwrapToScalar().to<float>();
8080

8181
if (1 != scalar) {
82-
auto scaleW = Weights(ctx, scalar);
83-
auto unuse = Weights();
84-
// IScaleLayer assert shift, scale and power to have
85-
// the same dtype
86-
auto scaleLayer = ctx->net->addScale(
87-
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
88-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
82+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
83+
auto scaleLayer = add_elementwise(
84+
ctx,
85+
nvinfer1::ElementWiseOperation::kPROD,
86+
other,
87+
alphaTensor,
88+
util::node_info(n) + std::string("_AlphaMultiplier"));
89+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
8990
other = scaleLayer->getOutput(0);
9091
}
9192

@@ -107,13 +108,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
107108
auto scalar = args[2].unwrapToScalar().to<float>();
108109

109110
if (1 != scalar) {
110-
auto scaleW = Weights(ctx, scalar);
111-
auto unuse = Weights();
112-
// IScaleLayer assert shift, scale and power to have
113-
// the same dtype
114-
auto scaleLayer = ctx->net->addScale(
115-
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
116-
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
111+
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
112+
auto scaleLayer = add_elementwise(
113+
ctx,
114+
nvinfer1::ElementWiseOperation::kPROD,
115+
other,
116+
alphaTensor,
117+
util::node_info(n) + std::string("_AlphaMultiplier"));
118+
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
117119
other = scaleLayer->getOutput(0);
118120
}
119121

tests/core/converters/test_element_wise.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,30 @@ TEST(Converters, ATenAddConvertsCorrectly) {
5252
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
5353
}
5454

55+
TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
56+
const auto graph = R"IR(
57+
graph(%0 : Tensor, %1 : Tensor):
58+
%2 : float = prim::Constant[value=3.2]()
59+
%3 : Tensor = aten::add(%0, %1, %2)
60+
return (%3))IR";
61+
pointwise_test_helper(graph, false);
62+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
63+
pointwise_test_helper(graph, false, false, {4}, {3, 4});
64+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
65+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
66+
}
67+
68+
TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
69+
const auto graph = R"IR(
70+
graph(%0 : Tensor, %1 : Tensor):
71+
%2 : float = prim::Constant[value=7.6]()
72+
%3 : Tensor = aten::add_(%0, %1, %2)
73+
return (%3))IR";
74+
pointwise_test_helper(graph, false);
75+
pointwise_test_helper(graph, false, false, {3, 4}, {4});
76+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
77+
}
78+
5579
TEST(Converters, ATenSubConvertsCorrectly) {
5680
const auto graph = R"IR(
5781
graph(%0 : Tensor, %1 : Tensor):
@@ -118,4 +142,4 @@ TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
118142
%3 : Tensor = aten::add(%0, %scalar, %2)
119143
return (%3))IR";
120144
pointwise_test_helper(graph, true);
121-
}
145+
}

0 commit comments

Comments
 (0)