@@ -53,23 +53,118 @@ void create_plugin(
   LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
 }
 
-auto normalize_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensor();
-       auto in_shape = util::toVec(in->getDimensions());
-       auto order = args[1].unwrapToScalar().to<int32_t>();
-       auto axes_values = args[2].unwrapToIntList().vec();
-       std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
-       auto keep_dims = (int32_t)args[3].unwrapToBool();
-       LOG_DEBUG("Order of normalize_plugin: " << order);
-       LOG_DEBUG("Axis: " << axes);
-       LOG_DEBUG("keep_dims: " << keep_dims);
-       create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
-       return true;
-     }
-
-    });
+int32_t axes_mask_from_axes_values(
+    const torch::jit::Node* n,
+    int32_t nb_dims,
+    const std::vector<int64_t>& axes_values) {
+  int32_t axes_mask = 0;
+  for (size_t i = 0UL; i < axes_values.size(); ++i) {
+    auto axis = axes_values[i];
+    if (axis < 0) {
+      axis += nb_dims;
+    }
+    TORCHTRT_CHECK(
+        axis < nb_dims, util::node_info(n) << " axis " << i << " with value: " << axis << " exceeds input rank");
+    axes_mask += 1 << axis;
+  }
+  return axes_mask;
+}
+
+nvinfer1::ITensor* frobenius_norm(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    int32_t axes_mask,
+    bool keep_dims) {
+  auto squared_layer =
+      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+  TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
+  auto squared_output = squared_layer->getOutput(0);
+
+  auto sum_layer = ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+  sum_layer->setName((util::node_info(n) + "_sum").c_str());
+  auto sum_output = sum_layer->getOutput(0);
+  LOG_DEBUG("SUM SHAPE: " << sum_output->getDimensions());
+
+  auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+  TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+  sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+  auto sqrt_output = sqrt_layer->getOutput(0);
+  return sqrt_output;
+}
+
+auto normalize_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto in_shape = util::toVec(in->getDimensions());
+               auto order = args[1].unwrapToScalar().to<int32_t>();
+               auto axes_values = args[2].unwrapToIntList().vec();
+               std::vector<int32_t> axes(axes_values.begin(), axes_values.end());
+               auto keep_dims = (int32_t)args[3].unwrapToBool();
+               LOG_DEBUG("Order of normalize_plugin: " << order);
+               LOG_DEBUG("Axis: " << axes);
+               LOG_DEBUG("keep_dims: " << keep_dims);
+               create_plugin(ctx, n, in, order, axes, keep_dims, "NormalizePluginTorchTRT");
+               return true;
+             }
+
+            })
+        .pattern(
+            {"aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto axes_values = args[1].unwrapToIntList().vec();
+               auto keep_dims = args[2].unwrapToBool();
+
+               auto axes_mask = axes_mask_from_axes_values(n, self->getDimensions().nbDims, axes_values);
+
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }})
+        .pattern(
+            {"aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
+               auto self = args[0].ITensorOrFreeze(ctx);
+               TORCHTRT_CHECK(
+                   args[1].IValue()->isNone(),
+                   "aten::linalg_norm converter does not yet support non-None 'ord' arguments. Add aten::linalg_norm to torch_executed_ops to force it to fallback.");
+               auto keep_dims = args[3].unwrapToBool();
+               auto self_nb_dims = self->getDimensions().nbDims;
+
+               if (!args.back().IValue()->isNone()) {
+                 // If specified, the input tensor is cast to dtype before performing the operation, and the returned
+                 // tensor's type will be dtype
+                 auto dtype = args.back().unwrapToScalar().to<int64_t>();
+                 auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(dtype));
+                 self = castITensor(ctx, self, trt_dtype);
+               }
+
+               int32_t axes_mask = 0;
+               if (args[2].IValue()->isNone()) {
+                 // If dim=None and ord=None, self will be flattened to 1D and the 2-norm of the resulting vector will
+                 // be computed.
+                 axes_mask = 1;
+                 keep_dims = true; // the single output dim is always preserved
+                 auto flatten_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(flatten_layer, "Unable to create shuffle layer from node: " << *n);
+                 flatten_layer->setReshapeDimensions(util::toDims(std::vector<int64_t>({-1})));
+                 flatten_layer->setName((util::node_info(n) + "_flatten").c_str());
+                 self = flatten_layer->getOutput(0);
+               } else {
+                 axes_mask = axes_mask_from_axes_values(n, self_nb_dims, args[2].unwrapToIntList().vec());
+               }
+               auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               return true;
+             }});
 
 } // namespace
 } // namespace impl
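
A note on the two helpers this hunk introduces: frobenius_norm lowers sqrt(sum(x * x)) into a TensorRT kPROD elementwise layer, a kSUM reduce layer, and a kSQRT unary layer, while axes_mask_from_axes_values converts a PyTorch `dim` list (including negative indices) into the axes bitmask that INetworkDefinition::addReduce expects. Below is a minimal standalone sketch of that bitmask arithmetic; it is not part of the commit, the helper name axes_mask_from_dims is illustrative, the TORCHTRT_CHECK rank validation is dropped, and nothing here depends on TensorRT.

// Standalone sketch: how a dim list maps to the reduce-axes bitmask.
#include <cstdint>
#include <iostream>
#include <vector>

int32_t axes_mask_from_dims(int32_t nb_dims, const std::vector<int64_t>& dims) {
  int32_t axes_mask = 0;
  for (auto axis : dims) {
    if (axis < 0) {
      axis += nb_dims; // wrap negative dims, e.g. -1 -> nb_dims - 1
    }
    axes_mask |= 1 << axis; // bit i selects dimension i for reduction
  }
  return axes_mask;
}

int main() {
  // For a rank-3 tensor, dim = {0, -1} reduces over axes 0 and 2:
  // 0b001 | 0b100 = 5, the value handed to ctx->net->addReduce in the diff above.
  std::cout << axes_mask_from_dims(3, {0, -1}) << "\n"; // prints 5
}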