@@ -53,6 +53,47 @@ void create_plugin(
  LOG_DEBUG("Normalize layer output tensor shape: " << layer_output->getDimensions());
}

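+ // Converts a list of reduction axes (normalizing negative values against the
+ // input rank nb_dims) into the bitmask form expected by TensorRT reduce layers.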
+ int32_t axes_mask_from_axes_values(
+     const torch::jit::Node* n,
+     int32_t nb_dims,
+     const std::vector<int64_t>& axes_values) {
+   int32_t axes_mask = 0;
+   for (size_t i = 0UL; i < axes_values.size(); ++i) {
+     auto axis = axes_values[i];
+     if (axis < 0) {
+       axis += nb_dims;
+     }
+     TORCHTRT_CHECK(
+         axis < nb_dims, util::node_info(n) << " axis " << i << " with value: " << axis << " exceeds input rank");
+     axes_mask += 1 << axis;
+   }
+   return axes_mask;
+ }
+
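+ // Computes sqrt(sum(self * self)) over the axes selected by axes_mask
+ // (the Frobenius norm of those dimensions) via elementwise-product,
+ // reduce-sum, and unary-sqrt layers.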
+ nvinfer1::ITensor* frobenius_norm(
+     ConversionCtx* ctx,
+     const torch::jit::Node* n,
+     nvinfer1::ITensor* self,
+     int32_t axes_mask,
+     bool keep_dims) {
+   auto squared_layer =
+       add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
+   TORCHTRT_CHECK(squared_layer, "Unable to create square layer from node: " << *n);
+   auto squared_output = squared_layer->getOutput(0);
+
+   auto sum_layer = ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
+   TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+   sum_layer->setName((util::node_info(n) + "_sum").c_str());
+   auto sum_output = sum_layer->getOutput(0);
+   LOG_DEBUG("SUM SHAPE: " << sum_output->getDimensions());
+
+   auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
+   TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
+   sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
+   auto sqrt_output = sqrt_layer->getOutput(0);
+   return sqrt_output;
+ }
+
auto normalize_registrations TORCHTRT_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern(
@@ -79,37 +120,48 @@ auto normalize_registrations TORCHTRT_UNUSED =
               auto axes_values = args[1].unwrapToIntList().vec();
               auto keep_dims = args[2].unwrapToBool();

-              int32_t axes_mask = 0;
-              auto self_nb_dims = self->getDimensions().nbDims;
-              for (size_t i = 0UL; i < axes_values.size(); ++i) {
-                auto axis = axes_values[i];
-                if (axis < 0) {
-                  axis += self_nb_dims;
-                }
-                TORCHTRT_CHECK(
-                    axis < self_nb_dims,
-                    "aten::frobenius_norm axis: " << i << " with value: " << axis << " exceeds input rank");
-                axes_mask += 1 << axis;
-              }
+              auto axes_mask = axes_mask_from_axes_values(n, self->getDimensions().nbDims, axes_values);

-              auto squared_layer = add_elementwise(
-                  ctx, nvinfer1::ElementWiseOperation::kPROD, self, self, util::node_info(n) + "_squared");
-              TORCHTRT_CHECK(squared_layer, "Unabled to create square layer from node: " << *n);
-              auto squared_output = squared_layer->getOutput(0);
-
-              auto sum_layer =
-                  ctx->net->addReduce(*squared_output, nvinfer1::ReduceOperation::kSUM, axes_mask, keep_dims);
-              TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-              sum_layer->setName((util::node_info(n) + "_sum").c_str());
-              auto sum_output = sum_layer->getOutput(0);
-
-              auto sqrt_layer = ctx->net->addUnary(*sum_output, nvinfer1::UnaryOperation::kSQRT);
-              TORCHTRT_CHECK(sqrt_layer, "Unable to create sqrt layer from node: " << *n);
-              sqrt_layer->setName((util::node_info(n) + "_sqrt").c_str());
-              auto sqrt_output = sqrt_layer->getOutput(0);
+              auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);
+              LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+              return true;
+            }})
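+         // Converter for aten::linalg_norm; only the default ord=None case
+         // (Frobenius norm over the requested dimensions) is handled here.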
+         .pattern(
+             {"aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                // https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
+                auto self = args[0].ITensorOrFreeze(ctx);
+                TORCHTRT_CHECK(
+                    args[1].IValue()->isNone(),
+                    "aten::linalg_norm converter does not yet support non-None 'ord' arguments. Add aten::linalg_norm to torch_executed_ops to force it to fallback.");
+                auto keep_dims = args[3].unwrapToBool();
+                auto self_nb_dims = self->getDimensions().nbDims;

-              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], sqrt_layer->getOutput(0));
+                if (!args.back().IValue()->isNone()) {
+                  // If specified, the input tensor is cast to dtype before performing the operation, and the returned
+                  // tensor’s type will be dtype
+                  auto dtype = args.back().unwrapToScalar().to<int64_t>();
+                  auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(dtype));
+                  self = castITensor(ctx, self, trt_dtype);
+                }

+                int32_t axes_mask = 0;
+                if (args[2].IValue()->isNone()) {
+                  // If dim=None and ord=None, self will be flattened to 1D and the 2-norm of the resulting vector will
+                  // be computed.
+                  axes_mask = 1;
+                  keep_dims = true; // the single output dim is always preserved
+                  auto flatten_layer = ctx->net->addShuffle(*self);
+                  TORCHTRT_CHECK(flatten_layer, "Unable to create shuffle layer from node: " << *n);
+                  flatten_layer->setReshapeDimensions(util::toDims(std::vector<int64_t>({-1})));
+                  flatten_layer->setName((util::node_info(n) + "_flatten").c_str());
+                  self = flatten_layer->getOutput(0);
+                } else {
+                  axes_mask = axes_mask_from_axes_values(n, self_nb_dims, args[2].unwrapToIntList().vec());
+                }
+                auto norm = frobenius_norm(ctx, n, self, axes_mask, keep_dims);
+                auto out = ctx->AssociateValueAndTensor(n->outputs()[0], norm);

              LOG_DEBUG("Output tensor shape: " << out->getDimensions());
              return true;
            }});