
Commit f179855

[feat] Add partial converter support for aten::linalg_norm (#1426)
* Add partial support for linalg_norm
* Add torch_executed_ops suggestion

1 parent: b5bcccf · commit: f179855

2 files changed (+146 −29 lines)

2 files changed

+146
-29
lines changed

core/conversion/converters/impl/normalize.cpp

Lines changed: 80 additions & 28 deletions
@@ -53,6 +53,47 @@ void create_plugin(
   LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
 }
 
+int32_t axes_mask_from_axes_values(
+    const torch::jit::Node* n,
+    int32_t nb_dims,
+    const std::vector<int64_t>& axes_values) {
+  int32_t axes_mask = 0;
+  for (size_t i = 0UL; i < axes_values.size(); ++i) {
+    auto axis = axes_values[i];
+    if (axis < 0) {
+      axis += nb_dims;
+    }
+    TORCHTRT_CHECK(
+        axis < nb_dims, util::node_info(n) << " axis " << i << " with value: " << axis << " exceeds input rank");
+    axes_mask += 1 << axis;
+  }
+  return axes_mask;
+}
+
+nvinfer1::ITensor* frobenius_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    int32_t axes_mask,
+    bool keep_dims) {
+  auto squared_layer =
+      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+  TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
+  auto squared_output = squared_layer->getOutput(0);
+
+  auto sum_layer = ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+  sum_layer->setName((util::node_info(n) + "_sum").c_str());
+  auto sum_output = sum_layer->getOutput(0);
+  LOG_DEBUG("SUM SHAPE: " << sum_output->getDimensions());
+
+  auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+  TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+  sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+  auto sqrt_output = sqrt_layer->getOutput(0);
+  return sqrt_output;
+}
+
 auto normalize_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
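TensorRT's addReduce API takes the reduction dimensions as a bitmask rather than a list (bit i set means "reduce over dimension i"), which is why the new axes_mask_from_axes_values helper normalizes negative dims and folds them into an int32_t. Below is a minimal standalone sketch of that encoding, mirroring the helper above; it is illustrative only and not part of the commit.

// Illustrative sketch (not from the commit) of the reduce-axes bitmask that
// axes_mask_from_axes_values builds for nvinfer1::INetworkDefinition::addReduce.
#include <cstdint>
#include <iostream>
#include <vector>

int32_t reduce_axes_mask(int32_t nb_dims, std::vector<int64_t> dims) {
  int32_t mask = 0;
  for (auto axis : dims) {
    if (axis < 0) {
      axis += nb_dims;  // normalize negative dims, e.g. -1 on rank 3 becomes 2
    }
    mask |= 1 << axis;  // equivalent to the helper's `+=` when dims are unique
  }
  return mask;
}

int main() {
  // Reducing dims {0, -1} of a rank-3 tensor sets bits 0 and 2: 0b101 == 5.
  std::cout << reduce_axes_mask(3, {0, -1}) << std::endl;  // prints 5
}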
@@ -79,37 +120,48 @@ auto normalize_registrations TORCHTRT_UNUSED =
               auto axes_values = args[1].unwrapToIntList().vec();
               auto keep_dims = args[2].unwrapToBool();
 
-              int32_t axes_mask = 0;
-              auto self_nb_dims = self->getDimensions().nbDims;
-              for (size_t i = 0UL; i < axes_values.size(); ++i) {
-                auto axis = axes_values[i];
-                if (axis < 0) {
-                  axis += self_nb_dims;
-                }
-                TORCHTRT_CHECK(
-                    axis < self_nb_dims,
-                    "aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank");
-                axes_mask += 1 << axis;
-              }
+              auto axes_mask = axes_mask_from_axes_values(n, self->getDimensions().nbDims, axes_values);
 
-              auto squared_layer = add_elementwise(
-                  ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
-              TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n);
-              auto squared_output = squared_layer->getOutput(0);
-
-              auto sum_layer =
-                  ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
-              TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-              sum_layer->setName((util::node_info(n) + "_sum").c_str());
-              auto sum_output = sum_layer->getOutput(0);
-
-              auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
-              TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
-              sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
-              auto sqrt_output = sqrt_layer->getOutput(0);
+              auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+              LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+              return true;
+            }})
+        .pattern(
+            {"aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
+               auto self = args[0].ITensorOrFreeze(ctx);
+               TORCHTRT_CHECK(
+                   args[1].IValue()->isNone(),
+                   "aten::linalg_norm converter does not yet support non-None 'ord' arguments. Add aten::linalg_norm to torch_executed_ops to force it to fallback.");
+               auto keep_dims = args[3].unwrapToBool();
+               auto self_nb_dims = self->getDimensions().nbDims;
 
-              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_layer->getOutput(0));
+               if (!args.back().IValue()->isNone()) {
+                 // If specified, the input tensor is cast to dtype before performing the operation, and the returned
+                 // tensor's type will be dtype
+                 auto dtype = args.back().unwrapToScalar().to<int64_t>();
+                 auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(dtype));
+                 self = castITensor(ctx, self, trt_dtype);
+               }
 
+               int32_t axes_mask = 0;
+               if (args[2].IValue()->isNone()) {
+                 // If dim=None and ord=None, self will be flattened to 1D and the 2-norm of the resulting vector will
+                 // be computed.
+                 axes_mask = 1;
+                 keep_dims = true; // the single output dim is always preserved
+                 auto flatten_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(flatten_layer, "Unable to create shuffle layer from node: " << *n);
+                 flatten_layer->setReshapeDimensions(util::toDims(std::vector<int64_t>({-1})));
+                 flatten_layer->setName((util::node_info(n) + "_flatten").c_str());
+                 self = flatten_layer->getOutput(0);
+               } else {
+                 axes_mask = axes_mask_from_axes_values(n, self_nb_dims, args[2].unwrapToIntList().vec());
+               }
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
               return true;
             }});
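The new converter rejects any non-None ord value and, as the commit message notes, suggests adding aten::linalg_norm to torch_executed_ops so those graphs fall back to Torch execution. Below is a hedged sketch of how a caller might do that with the C++ TorchScript frontend; it assumes the torch_tensorrt::torchscript::CompileSpec API with its torch_executed_ops field (check your release's headers), and the input shape is a placeholder.

// Sketch only (not part of this commit): force aten::linalg_norm to run in
// Torch instead of TensorRT, e.g. when the graph uses a non-None `ord` that
// this converter does not yet support. Assumes the C++ TorchScript frontend;
// the Python API exposes an analogous torch_executed_ops argument.
#include <torch/script.h>
#include <torch_tensorrt/torch_tensorrt.h>

torch::jit::Module compile_with_linalg_norm_fallback(torch::jit::Module& mod) {
  auto spec = torch_tensorrt::torchscript::CompileSpec({torch_tensorrt::Input({5, 5, 5})});
  spec.torch_executed_ops.push_back("aten::linalg_norm");  // run this op in Torch
  return torch_tensorrt::torchscript::compile(mod, spec);
}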

tests/core/conversion/converters/test_normalize.cpp

Lines changed: 66 additions & 1 deletion
@@ -138,4 +138,69 @@ TEST(Converters, ATenFrobeniusNormMatrix) {
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
-}
+}
+
+TEST(Converters, ATenLinAlgNorm_None) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=0]()
+        %out : Tensor = aten::linalg_norm(%x, %none, %none, %keep, %none)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+  auto x = at::randn({5, 5, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenLinAlgNorm_1D) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=0]()
+        %dims : int[] = prim::ListConstruct(%1)
+        %out : Tensor = aten::linalg_norm(%x, %none, %dims, %keep, %none)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto x = at::randn({5, 5, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
+
+TEST(Converters, ATenLinAlgNorm_2D) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %0 : int = prim::Constant[value=0]()
+        %1 : int = prim::Constant[value=-1]()
+        %none : NoneType = prim::Constant()
+        %keep : bool = prim::Constant[value=1]()
+        %dims : int[] = prim::ListConstruct(%0, %1)
+        %float : int = prim::Constant[value=6]()
+        %out : Tensor = aten::linalg_norm(%x, %none, %dims, %keep, %float)
+        return (%out))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto x = at::randn({5, 5, 5}, {at::kCUDA}).to(at::kHalf);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {x});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {x});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
+}
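The tests exercise the converter through TorchScript IR; for the supported ord=None path the expected result reduces to sqrt(sum(x * x)) over the requested dims, which is what the frobenius_norm helper builds out of TensorRT layers. The following small libtorch sketch of that reference computation for the dims used by ATenLinAlgNorm_2D is illustrative only and not part of the commit.

// Illustrative reference (not from the commit) for the ord=None path:
// sqrt(sum(x * x)) over the given dims, matching the frobenius_norm helper.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({5, 5, 5});
  // dims {0, -1} with keepdim=true, as in the ATenLinAlgNorm_2D test graph
  auto ref = (x * x).sum({0, -1}, /*keepdim=*/true).sqrt();
  std::cout << ref.sizes() << std::endl;  // [1, 5, 1]
}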
