@@ -161,31 +161,41 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
161
161
162
162
void RunAxisTransformInstr (const std::shared_ptr<AxisTransformInstr>& instr,
163
163
FusionInterpreter* interpreter) {
164
- auto substitute_dimexpr_for_shape = [&](std::vector<symbol::DimExpr>& shape) {
165
- for (auto & dim_expr : shape) {
166
- if (dim_expr.isa <std::int64_t >()) continue ;
167
- symbol::DimExpr origin_dim_expr = dim_expr;
168
- while (true ) {
169
- dim_expr = symbol::SubstituteDimExpr (
170
- dim_expr, interpreter->substitute_dimexpr_map );
171
- if (dim_expr == origin_dim_expr || dim_expr.isa <std::int64_t >()) break ;
172
- origin_dim_expr = dim_expr;
173
- }
174
- }
175
- };
176
- auto substitute_dimexpr_for_transform =
177
- adt::match{[&](const AppendAxisTransformPtr& transform) {
178
- substitute_dimexpr_for_shape (transform->shape );
179
- },
180
- [&](const ReshapeTransformPtr& transform) {
181
- substitute_dimexpr_for_shape (transform->in_shape );
182
- substitute_dimexpr_for_shape (transform->out_shape );
183
- },
184
- [&](const auto & transform) {}};
164
+ auto substitute_dimexpr_for_shape =
165
+ [&](const std::vector<symbol::DimExpr>& shape) {
166
+ std::vector<symbol::DimExpr> result;
167
+ for (const auto & dim_expr : shape) {
168
+ symbol::DimExpr substituted = dim_expr;
169
+ while (true ) {
170
+ if (substituted.isa <std::int64_t >()) break ;
171
+ auto tmp_substituted = symbol::SubstituteDimExpr (
172
+ substituted, interpreter->substitute_dimexpr_map );
173
+ if (tmp_substituted == substituted) break ;
174
+ substituted = tmp_substituted;
175
+ }
176
+ result.emplace_back (substituted);
177
+ }
178
+ return result;
179
+ };
180
+ auto substitute_dimexpr_for_transform = adt::match{
181
+ [&](const AppendAxisTransformPtr& trans) -> AxisTransform {
182
+ auto substituted_shape = substitute_dimexpr_for_shape (trans->shape );
183
+ return std::make_shared<AppendAxisTransform>(trans->axis ,
184
+ substituted_shape);
185
+ },
186
+ [&](const ReshapeTransformPtr& trans) -> AxisTransform {
187
+ auto substituted_in_shape =
188
+ substitute_dimexpr_for_shape (trans->in_shape );
189
+ auto substituted_out_shape =
190
+ substitute_dimexpr_for_shape (trans->out_shape );
191
+ return std::make_shared<ReshapeTransform>(substituted_in_shape,
192
+ substituted_out_shape);
193
+ },
194
+ [&](const auto & trans) -> AxisTransform { return trans; }};
185
195
auto axis_transform = [&](ir::Expr op_expr) -> ir::Expr {
186
196
for (auto trans : instr->axis_transform_route_ ) {
187
- std::visit (substitute_dimexpr_for_transform, trans);
188
- op_expr = std::visit (ApplyAxisTransform (op_expr), trans );
197
+ auto new_trans = std::visit (substitute_dimexpr_for_transform, trans);
198
+ op_expr = std::visit (ApplyAxisTransform (op_expr), new_trans );
189
199
}
190
200
return op_expr;
191
201
};
0 commit comments