@@ -1044,7 +1044,7 @@ struct ReshapeExtend final
 
     // Then create a new extend operation on the reshaped data
     auto newExtendOp = enzymexla::ExtendOp::create(rewriter, op.getLoc(),
-                                                   newReshapeOp.getResult(),
+                                                   newReshapeOp,
                                                    lhs, rhs, newExtendDim);
 
     // Replace the original reshape op with the new extend operation
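
Note on this hunk: the `.getResult()` call goes away because `newReshapeOp` is presumably now the plain `mlir::Value` returned by the `stablehlo::ReshapeOpCreate` helper this commit migrates to, not a `ReshapeOp` handle. A minimal sketch of the distinction, under that assumption (names here are illustrative, not the repository's):

    // Old style: the builder returns the op, whose SSA result must be
    // unwrapped before it can feed another builder.
    auto reshapeOp =
        stablehlo::ReshapeOp::create(rewriter, loc, resultTy, input);
    mlir::Value v = reshapeOp.getResult();

    // Assumed new helper style: a Value comes back directly, so it can be
    // passed straight into ExtendOp::create.
    mlir::Value newReshapeOp =
        stablehlo::ReshapeOpCreate(rewriter, loc, input, newShape);
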
@@ -6102,7 +6102,7 @@ struct ConcatToBroadcast final
         rewriter, op.getLoc(), op->getOperand(0), reshaped);
 
     SmallVector<int64_t> bcast;
-    for (auto en : llvm::enumerate(reshapeVal.getType().getShape())) {
+    for (auto en : llvm::enumerate(cast<RankedTensorType>(reshapeVal.getType()).getShape())) {
       bcast.push_back(en.index());
     }
     reshaped[op.getDimension()] = op->getNumOperands();
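
The added `cast<RankedTensorType>` is needed because a generic `mlir::Value` exposes only `mlir::Type` through `getType()`, while `getShape()` lives on `RankedTensorType`; with `reshapeVal` now a `Value` rather than a typed op result, the cast recovers the shaped type. For orientation, a hypothetical sketch of what the `ReshapeOpCreate` helper plausibly does; its real definition is not part of this diff, so the signature below is an assumption:

    // Hypothetical helper (assumed, not the actual Enzyme-JAX definition):
    // build a reshape to `shape`, deriving the result type from the
    // operand's element type, and return the result as a plain Value.
    static mlir::Value ReshapeOpCreate(mlir::OpBuilder &rewriter,
                                       mlir::Location loc,
                                       mlir::Value operand,
                                       llvm::ArrayRef<int64_t> shape) {
      auto elemTy = mlir::cast<mlir::RankedTensorType>(operand.getType())
                        .getElementType();
      auto resultTy = mlir::RankedTensorType::get(shape, elemTy);
      return stablehlo::ReshapeOp::create(rewriter, loc, resultTy, operand)
          .getResult();
    }
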
@@ -21652,8 +21652,6 @@ struct SumToConv : public SumToReductionBase<ST, SumToConv<ST>> {
     if (conv.getType() != pre_reshape) {
       SmallVector<int64_t> post_shape = llvm::to_vector(pre_reshape.getShape());
       post_shape[reshapeOffsetDim] -= (lastidx - startidx);
-      RankedTensorType post_reshape =
-          RankedTensorType::get(post_shape, pre_reshape.getElementType());
       conv = stablehlo::ReshapeOpCreate(rewriter, input.getLoc(),
                                         conv, post_shape);
     }
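
With the helper computing the result type internally from `post_shape`, the explicitly built `RankedTensorType post_reshape` became a dead local (and a -Wunused-variable warning), hence its deletion; the same cleanup repeats in the DotGeneralInsertDimContractionSimplification and DeleteDimsBroadcast hunks below. A sketch of the presumed before/after:

    // Before (sketch): the raw builder needed the result type up front.
    RankedTensorType post_reshape =
        RankedTensorType::get(post_shape, pre_reshape.getElementType());
    conv = stablehlo::ReshapeOp::create(rewriter, input.getLoc(),
                                        post_reshape, conv);

    // After: the helper derives the type from post_shape, so the explicit
    // RankedTensorType local has no remaining use.
    conv = stablehlo::ReshapeOpCreate(rewriter, input.getLoc(),
                                      conv, post_shape);
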
@@ -22794,7 +22792,7 @@ struct RecognizeWrap
         auto reshape = stablehlo::ReshapeOpCreate(
             rewriter, concat.getLoc(), wrap, newShape);
         if (auto shard = sdy::getShardingPerValue(rs0)) {
-          sdy::setShardings(reshape, shard);
+          sdy::setShardings(reshape.getDefiningOp(), shard);
         }
         toConcat.push_back(reshape);
         for (int j = i + 1; j < operands.size(); j++)
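
Here `reshape` is a `Value`, so Shardy's `setShardings` (which annotates an operation) can no longer take it directly; `Value::getDefiningOp()` recovers the producing `ReshapeOp` first. The same one-line fix appears twice more below. A minimal sketch, assuming Shardy exposes a `setShardings(mlir::Operation *, ...)` overload as used in this diff:

    // getDefiningOp() maps the SSA result back to the operation that
    // created it, which is where the sharding attribute is attached.
    // It returns null for block arguments, but that cannot happen here
    // since `reshape` was just produced by the helper.
    if (auto shard = sdy::getShardingPerValue(rs0))
      sdy::setShardings(reshape.getDefiningOp(), shard);
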
@@ -22806,7 +22804,7 @@ struct RecognizeWrap
       rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
           concat, toConcat, concatDim);
       if (shard) {
-        sdy::setShardings(reshape, shard);
+        sdy::setShardings(reshape.getDefiningOp(), shard);
       }
     }
     return success();
@@ -23234,7 +23232,7 @@ struct RecognizeExtend
       auto reshape = stablehlo::ReshapeOpCreate(
           rewriter, concat.getLoc(), extend, shape);
       if (auto shard = sdy::getShardingPerValue(concat)) {
-        sdy::setShardings(reshape, shard);
+        sdy::setShardings(reshape.getDefiningOp(), shard);
       }
       finish(reshape);
       return success();
@@ -31316,11 +31314,6 @@ struct DotGeneralInsertDimContractionSimplification final
     }
 
     // Create reshaped operands. This will be cleaned up later
-    auto newLhsType =
-        RankedTensorType::get(newLhsShape, lhsType.getElementType());
-    auto newRhsType =
-        RankedTensorType::get(newRhsShape, rhsType.getElementType());
-
     Value newLhs =
         stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), lhs, newLhsShape);
     Value newRhs =
@@ -31428,8 +31421,6 @@ struct DeleteDimsBroadcast final
     }
 
     // Create the reshape on the input
-    auto newInputTy =
-        RankedTensorType::get(newInputShape, bcastInputTy.getElementType());
     auto reshapeInput = stablehlo::ReshapeOpCreate(
         rewriter, op.getLoc(), bcastOp.getOperand(), newInputShape);
 