Skip to content

Commit b9ddc86

Browse files
committed
support aten::div.Scalar and aten::div_.Scalar(Tensor self, Scalar other) converters; split the aten::div evaluator schema into aten::div.float and aten::div.int
Signed-off-by: inocsin <[email protected]>
1 parent 4d3ac4f commit b9ddc86

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
213213
div->setName(util::node_info(n).c_str());
214214
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
215215

216+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
217+
return true;
218+
}})
219+
.pattern({"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
220+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
221+
// TODO: Remove with functionalization
222+
auto self = args[0].ITensorOrFreeze(ctx);
223+
auto otherScalar = args[1].unwrapToScalar().to<float>();
224+
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
225+
auto div =
226+
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
227+
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
228+
229+
div->setName(util::node_info(n).c_str());
230+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
216231
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
217232
return true;
218233
}})
@@ -229,6 +244,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
229244
div->setName(util::node_info(n).c_str());
230245
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
231246

247+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
248+
return true;
249+
}})
250+
.pattern({"aten::div_.Scalar(Tensor self, Scalar other) -> (Tensor)",
251+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
252+
// TODO: Remove with functionalization
253+
auto self = args[0].ITensorOrFreeze(ctx);
254+
auto otherScalar = args[1].unwrapToScalar().to<float>();
255+
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
256+
auto div =
257+
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
258+
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
259+
260+
div->setName(util::node_info(n).c_str());
261+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
232262
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
233263
return true;
234264
}})

core/conversion/evaluators/aten.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ auto aten_registrations TRTORCH_UNUSED =
426426
}
427427
},
428428
EvalOptions().validSchemas({
429-
"aten::div.Scalar(Scalar a, Scalar b) -> (float)",
429+
"aten::div.float(float a, float b) -> (float)",
430+
"aten::div.int(int a, int b) -> (int)",
430431
})})
431432
.evaluator({c10::Symbol::fromQualString("aten::floordiv"),
432433
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ TEST(Converters, ATenDivConvertsCorrectly) {
123123
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
124124
}
125125

126+
TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
127+
const auto graph = R"IR(
128+
graph(%0 : Tensor):
129+
%scalar : float = prim::Constant[value=2.4]()
130+
%1 : Tensor = aten::div(%0, %scalar)
131+
return (%1))IR";
132+
pointwise_test_helper(graph, true);
133+
}
134+
126135
TEST(Converters, ATenPowTensorConvertsCorrectly) {
127136
const auto graph = R"IR(
128137
graph(%x.1 : Tensor, %x2.1 : Tensor):

0 commit comments

Comments (0)