@@ -10,138 +10,84 @@ namespace converters {
10
10
namespace impl {
11
11
namespace {
12
12
13
+ nvinfer1::ITensor* broadcast (
14
+ ConversionCtx* ctx,
15
+ const torch::jit::Node* n,
16
+ nvinfer1::ITensor* to_broadcast,
17
+ const int nbDims,
18
+ const std::string& tag) {
19
+ auto to_broadcast_nbdims = to_broadcast->getDimensions ().nbDims ;
20
+ TORCHTRT_CHECK (to_broadcast_nbdims <= nbDims, " Cannot broadcast tensor with more dimensions than the target" );
21
+ if (to_broadcast_nbdims == nbDims) {
22
+ return to_broadcast;
23
+ }
24
+ auto shape_layer = ctx->net ->addShape (*to_broadcast);
25
+ TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
26
+ shape_layer->setName ((util::node_info (n) + " _shape_" + tag).c_str ());
27
+ auto shape_layer_out = shape_layer->getOutput (0 );
28
+
29
+ auto extra_dims_tensor = torch::ones ({nbDims - to_broadcast_nbdims}, torch::TensorOptions ().dtype (torch::kInt32 ));
30
+ auto extra_dims_itensor = tensor_to_const (ctx, extra_dims_tensor);
31
+
32
+ std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
33
+ auto concat_layer = ctx->net ->addConcatenation (to_concat.data (), to_concat.size ());
34
+ TORCHTRT_CHECK (concat_layer, " Unable to create concat layer from node: " << *n);
35
+ concat_layer->setName ((util::node_info (n) + " _concat_" + tag).c_str ());
36
+ auto target_shape = concat_layer->getOutput (0 );
37
+
38
+ auto shuffle_layer = ctx->net ->addShuffle (*to_broadcast);
39
+ TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
40
+ shuffle_layer->setName ((util::node_info (n) + " _shuffle_" + tag).c_str ());
41
+ shuffle_layer->setInput (1 , *target_shape);
42
+ auto output = shuffle_layer->getOutput (0 );
43
+ LOG_DEBUG (
44
+ " Broadcast " << tag << " to shape: " << output->getDimensions () << " from " << to_broadcast->getDimensions ());
45
+ return output;
46
+ }
47
+
13
48
auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
14
49
R"SIG( aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
15
50
float eps, bool cudnn_enabled) -> (Tensor))SIG" ,
16
51
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17
- auto input = args[0 ].ITensor (); // assumes non-static input Tensor
18
- auto orig_shape = input->getDimensions ();
19
- auto shape = util::toVec (orig_shape);
20
-
21
- /* Layer_Norm normalizes over last N dimensions.
22
- normalizaed_shape could be (C,H,W), (H,W), or (W). */
23
- // This could be an IntList or ITensorList. We only need the size of this list.
24
- auto normalized_shape = args[1 ].IValue ()->toList ();
25
-
26
- // Unwrap eps.
27
- auto eps = args[4 ].unwrapToDouble ();
28
-
29
- LOG_DEBUG (" cudnn disregarded" );
30
-
31
- // Set up axis_ask for E[x].
32
- uint32_t axis_mask = 0 ;
33
- for (size_t i = 0 ; i < normalized_shape.size (); i++) {
34
- axis_mask |= 1 << (shape.size () - i - 1 );
52
+ auto input = args[0 ].ITensorOrFreeze (ctx);
53
+ auto input_shape = input->getDimensions ();
54
+ auto input_shape_vec = util::toVec (input_shape);
55
+ auto normalized_shape = args[1 ].unwrapToIntList ();
56
+ auto normalized_shape_vec = util::toVec (util::toDims (normalized_shape));
57
+ auto axis = input_shape_vec.size () - normalized_shape_vec.size ();
58
+ uint32_t axes_mask = 0 ;
59
+ for (size_t i = axis; i < input_shape_vec.size (); i++) {
60
+ axes_mask |= 1 << i;
35
61
}
36
- LOG_DEBUG (" Axis Mask for E[x]" << std::bitset<32 >(axis_mask));
37
-
38
- // E[x]
39
- auto mean_expected = ctx->net ->addReduce (*input, nvinfer1::ReduceOperation::kAVG , axis_mask, true );
40
- TORCHTRT_CHECK (mean_expected, " Unable to create mean_expected from node: " << *n);
41
- mean_expected->setName ((util::node_info (n) + " _mean_expected" ).c_str ());
42
- auto mean_expected_out = mean_expected->getOutput (0 );
43
-
44
- // X-E[x]
45
- auto sub = add_elementwise (
46
- ctx, nvinfer1::ElementWiseOperation::kSUB , input, mean_expected_out, (util::node_info (n) + " _sub" ).c_str ());
47
- TORCHTRT_CHECK (sub, " Unable to create Sub layer from node: " << *n);
48
- sub->setName ((util::node_info (n) + " _sub" ).c_str ());
49
- auto xsubmean_out = sub->getOutput (0 );
50
-
51
- // Variance = mean(pow(xsubmean,2))
52
- float pow_scalar = 2 ;
53
- auto exponent = tensor_to_const (ctx, torch::tensor ({pow_scalar}));
54
- auto pow = add_elementwise (
55
- ctx, nvinfer1::ElementWiseOperation::kPOW , xsubmean_out, exponent, (util::node_info (n) + " _pow" ).c_str ());
56
- TORCHTRT_CHECK (pow, " Unable to create Pow layer from node: " << *n);
57
- pow->setName ((util::node_info (n) + " _pow" ).c_str ());
58
- auto pow_out = pow->getOutput (0 );
59
-
60
- auto mean_var = ctx->net ->addReduce (*pow_out, nvinfer1::ReduceOperation::kAVG , axis_mask, true );
61
- TORCHTRT_CHECK (mean_var, " Unable to create mean_var from node: " << *n);
62
- mean_var->setName ((util::node_info (n) + " _mean_var" ).c_str ());
63
- auto mean_var_out = mean_var->getOutput (0 );
64
-
65
- // Variance + eps
66
- auto eps_tensor = tensor_to_const (ctx, torch::tensor ({eps}));
67
- auto add = add_elementwise (
68
- ctx, nvinfer1::ElementWiseOperation::kSUM , mean_var_out, eps_tensor, (util::node_info (n) + " _add" ).c_str ());
69
- TORCHTRT_CHECK (add, " Unable to create Add layer from node: " << *n);
70
- add->setName ((util::node_info (n) + " _add" ).c_str ());
71
- auto add_out = add->getOutput (0 );
72
62
73
- // SQRT((Var + eps))
74
- auto sqrt = ctx->net ->addUnary (*add_out, nvinfer1::UnaryOperation::kSQRT );
75
- TORCHTRT_CHECK (sqrt, " Unable to create unary(sqrt) from node: " << *n);
76
- sqrt->setName ((util::node_info (n) + " _sqrt" ).c_str ());
77
- auto sqrt_out = sqrt->getOutput (0 );
78
-
79
- // (x - E[x]) / sqrt((var + eps))
80
- auto div = add_elementwise (
81
- ctx, nvinfer1::ElementWiseOperation::kDIV , xsubmean_out, sqrt_out, (util::node_info (n) + " _div" ).c_str ());
82
- TORCHTRT_CHECK (div, " Unable to create div layer from node: " << *n);
83
- div->setName ((util::node_info (n) + " _div" ).c_str ());
84
- auto div_out = div->getOutput (0 );
85
-
86
- if (!args[2 ].IValue ()->isTensor () && !args[3 ].IValue ()->isTensor ()) {
87
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], div_out);
88
- return true ;
89
- }
90
-
91
- // Remove batch dimension from input shape for expand_size, which will
92
- // be used to create weights for addScaleNd later.
93
- auto expand_size = shape;
94
- expand_size.erase (expand_size.begin (), expand_size.begin () + 1 );
95
-
96
- // Set up gamma_weights and beta_weights from gamma_expand and
97
- // beta_expand.
98
- auto gamma_weights = Weights (ctx, at::ones (expand_size));
99
- auto beta_weights = Weights (ctx, at::zeros (expand_size));
100
-
101
- if (args[2 ].IValue ()->isTensor ()) {
102
- torch::Tensor gamma;
103
- gamma = args[2 ].unwrapToTensor ();
104
- auto gamma_expand = gamma.expand (expand_size);
105
- gamma_weights = Weights (ctx, gamma_expand);
63
+ nvinfer1::ITensor* gamma = nullptr ;
64
+ if (args[2 ].IValue ()->isNone ()) {
65
+ auto gamma_torch_tensor = torch::ones (input_shape_vec, torch::TensorOptions ().dtype (torch::kFloat32 ));
66
+ gamma = tensor_to_const (ctx, gamma_torch_tensor);
106
67
} else {
107
- gamma_weights = Weights (ctx, at::ones (expand_size));
68
+ gamma = args[2 ].ITensorOrFreeze (ctx);
69
+ gamma = broadcast (ctx, n, gamma, input_shape_vec.size (), " gamma" );
108
70
}
109
71
110
- if (args[3 ].IValue ()->isTensor ()) {
111
- torch::Tensor beta;
112
- beta = args[3 ].unwrapToTensor ();
113
- auto beta_expand = beta.expand (expand_size);
114
- beta_weights = Weights (ctx, beta_expand);
72
+ nvinfer1::ITensor* beta = nullptr ;
73
+ if (args[3 ].IValue ()->isNone ()) {
74
+ auto beta_torch_tensor = torch::zeros (input_shape_vec, torch::TensorOptions ().dtype (torch::kFloat32 ));
75
+ beta = tensor_to_const (ctx, beta_torch_tensor);
115
76
} else {
116
- beta_weights = Weights (ctx, at::zeros (expand_size));
77
+ beta = args[3 ].ITensorOrFreeze (ctx);
78
+ beta = broadcast (ctx, n, beta, input_shape_vec.size (), " beta" );
117
79
}
118
80
119
- auto power = Weights (ctx, at::ones (expand_size));
120
-
121
- auto gamma_tensor = ctx->net ->addConstant (gamma_weights.shape , gamma_weights.data )->getOutput (0 );
122
- auto scale_l = add_elementwise (
123
- ctx, nvinfer1::ElementWiseOperation::kPROD , div_out, gamma_tensor, (util::node_info (n) + " _scale" ).c_str ());
124
-
125
- auto beta_tensor = ctx->net ->addConstant (beta_weights.shape , beta_weights.data )->getOutput (0 );
126
- auto shift_l = add_elementwise (
127
- ctx,
128
- nvinfer1::ElementWiseOperation::kSUM ,
129
- scale_l->getOutput (0 ),
130
- beta_tensor,
131
- (util::node_info (n) + " _shift" ).c_str ());
132
-
133
- auto power_tensor = ctx->net ->addConstant (power.shape , power.data )->getOutput (0 );
134
- auto power_l = add_elementwise (
135
- ctx,
136
- nvinfer1::ElementWiseOperation::kPOW ,
137
- shift_l->getOutput (0 ),
138
- power_tensor,
139
- (util::node_info (n) + " _power" ).c_str ());
81
+ auto eps = args[4 ].unwrapToDouble ();
140
82
141
- power_l->setName ((util::node_info (n) + " _scale_nd" ).c_str ());
142
- auto power_l_out = power_l->getOutput (0 );
83
+ auto normalize_layer = ctx->net ->addNormalization (*input, *gamma, *beta, axes_mask);
84
+ TORCHTRT_CHECK (normalize_layer, " Unable to create layer_norm from node: " << *n);
85
+ normalize_layer->setName (util::node_info (n).c_str ());
86
+ normalize_layer->setEpsilon (eps);
87
+ normalize_layer->setComputePrecision (nvinfer1::DataType::kFLOAT );
88
+ auto normalized = normalize_layer->getOutput (0 );
143
89
144
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], power_l_out );
90
+ ctx->AssociateValueAndTensor (n->outputs ()[0 ], normalized );
145
91
return true ;
146
92
}});
147
93
0 commit comments