@@ -1097,7 +1097,7 @@ struct ReshapeExtend final
 
     // Then create a new extend operation on the reshaped data
    auto newExtendOp = enzymexla::ExtendOp::create(rewriter, op.getLoc(),
-                                                  newReshapeOp.getResult(),
+                                                  newReshapeOp,
                                                   lhs, rhs, newExtendDim);
 
    // Replace the original reshape op with the new extend operation
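
Note on the hunk above: in MLIR, op classes with exactly one result provide an implicit conversion to `Value` (via the `OneTypedResult` trait), so passing `newReshapeOp` and passing `newReshapeOp.getResult()` build identical IR; the change just drops the redundant call. A minimal sketch of the two equivalent spellings (the reshape builder arguments here are illustrative, not the repository's exact call):

    // Single-result ops convert implicitly to Value.
    stablehlo::ReshapeOp newReshapeOp =
        stablehlo::ReshapeOp::create(rewriter, loc, newType, input);
    Value a = newReshapeOp.getResult(); // explicit form (removed above)
    Value b = newReshapeOp;             // implicit conversion (added above)
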
@@ -6155,7 +6155,7 @@ struct ConcatToBroadcast final
         rewriter, op.getLoc(), op->getOperand(0), reshaped);
 
     SmallVector<int64_t> bcast;
-    for (auto en : llvm::enumerate(reshapeVal.getType().getShape())) {
+    for (auto en : llvm::enumerate(cast<RankedTensorType>(reshapeVal.getType()).getShape())) {
       bcast.push_back(en.index());
     }
     reshaped[op.getDimension()] = op->getNumOperands();
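
Context for the added cast: `ReshapeOpCreate` appears to return a plain `mlir::Value` (an assumption based on how its results are used throughout this diff), and `Value::getType()` yields an untyped `mlir::Type`, which has no `getShape()`. Casting to `RankedTensorType` first recovers the shaped-type API, and the cast is safe here because a stablehlo reshape result is always a ranked tensor. A hedged sketch of the pattern:

    // Requires mlir/IR/BuiltinTypes.h and llvm/ADT/STLExtras.h.
    // Value::getType() returns mlir::Type; only shaped types expose getShape().
    auto rtt = cast<RankedTensorType>(reshapeVal.getType());
    SmallVector<int64_t> bcast;
    for (auto en : llvm::enumerate(rtt.getShape()))
      bcast.push_back(en.index()); // identity mapping over dims 0..rank-1
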
@@ -21705,8 +21705,6 @@ struct SumToConv : public SumToReductionBase<ST, SumToConv<ST>> {
     if (conv.getType() != pre_reshape) {
       SmallVector<int64_t> post_shape = llvm::to_vector(pre_reshape.getShape());
       post_shape[reshapeOffsetDim] -= (lastidx - startidx);
-      RankedTensorType post_reshape =
-          RankedTensorType::get(post_shape, pre_reshape.getElementType());
       conv = stablehlo::ReshapeOpCreate(rewriter, input.getLoc(),
                                         conv, post_shape);
     }
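
The two deleted locals above, like the matching deletions in the last two hunks below, are the same cleanup: once `ReshapeOpCreate` accepts the target shape directly, the explicitly constructed `RankedTensorType` is dead code. A plausible shape-taking wrapper, written here as an assumption since the repository's actual helper is not shown in this diff:

    // Hypothetical helper: derives the result type from the operand's element
    // type, so callers pass only the new shape.
    static Value ReshapeOpCreate(OpBuilder &rewriter, Location loc, Value input,
                                 ArrayRef<int64_t> shape) {
      auto elemTy = cast<RankedTensorType>(input.getType()).getElementType();
      auto resultTy = RankedTensorType::get(shape, elemTy);
      return stablehlo::ReshapeOp::create(rewriter, loc, resultTy, input);
    }
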
@@ -22847,7 +22845,7 @@ struct RecognizeWrap
       auto reshape = stablehlo::ReshapeOpCreate(
           rewriter, concat.getLoc(), wrap, newShape);
       if (auto shard = sdy::getShardingPerValue(rs0)) {
-        sdy::setShardings(reshape, shard);
+        sdy::setShardings(reshape.getDefiningOp(), shard);
       }
       toConcat.push_back(reshape);
       for (int j = i + 1; j < operands.size(); j++)
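
This `.getDefiningOp()` fix, repeated in the next two hunks, follows from the same return-type assumption: `reshape` is a `Value`, while `sdy::setShardings` attaches sharding attributes to an operation, so the call needs the producing `Operation*` (Shardy's exact signatures are not shown in this diff and are assumed). Sketch:

    // getDefiningOp() is non-null here because `reshape` was just materialized
    // by ReshapeOpCreate; only block arguments would yield a null defining op.
    if (auto shard = sdy::getShardingPerValue(rs0))
      sdy::setShardings(reshape.getDefiningOp(), shard);
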
@@ -22859,7 +22857,7 @@ struct RecognizeWrap
       rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
           concat, toConcat, concatDim);
       if (shard) {
-        sdy::setShardings(reshape, shard);
+        sdy::setShardings(reshape.getDefiningOp(), shard);
       }
     }
     return success();
@@ -23287,7 +23285,7 @@ struct RecognizeExtend
     auto reshape = stablehlo::ReshapeOpCreate(
         rewriter, concat.getLoc(), extend, shape);
     if (auto shard = sdy::getShardingPerValue(concat)) {
-      sdy::setShardings(reshape, shard);
+      sdy::setShardings(reshape.getDefiningOp(), shard);
     }
     finish(reshape);
     return success();
@@ -31371,11 +31369,6 @@ struct DotGeneralInsertDimContractionSimplification final
     }
 
     // Create reshaped operands. This will be cleaned up later
-    auto newLhsType =
-        RankedTensorType::get(newLhsShape, lhsType.getElementType());
-    auto newRhsType =
-        RankedTensorType::get(newRhsShape, rhsType.getElementType());
-
     Value newLhs =
         stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), lhs, newLhsShape);
     Value newRhs =
@@ -31483,8 +31476,6 @@ struct DeleteDimsBroadcast final
     }
 
     // Create the reshape on the input
-    auto newInputTy =
-        RankedTensorType::get(newInputShape, bcastInputTy.getElementType());
     auto reshapeInput = stablehlo::ReshapeOpCreate(
         rewriter, op.getLoc(), bcastOp.getOperand(), newInputShape);
 