Skip to content

Commit e8daa35

Browse files
Copilotavik-pal
andcommitted
Fix compilation errors in ReshapeOpCreate refactoring
- Remove .getResult() call on Value at line 1047 - Cast Value to RankedTensorType before getShape() at line 6105 - Remove unused variable post_reshape at line 21655 - Call getDefiningOp() on Value before setShardings at lines 22797, 22809, 23237 - Remove unused variables newLhsType, newRhsType at lines 31319-31322 - Remove unused variable newInputTy at line 31431 Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com>
1 parent 81dc527 commit e8daa35

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ struct ReshapeExtend final
10441044

10451045
// Then create a new extend operation on the reshaped data
10461046
auto newExtendOp = enzymexla::ExtendOp::create(rewriter, op.getLoc(),
1047-
newReshapeOp.getResult(),
1047+
newReshapeOp,
10481048
lhs, rhs, newExtendDim);
10491049

10501050
// Replace the original reshape op with the new extend operation
@@ -6102,7 +6102,7 @@ struct ConcatToBroadcast final
61026102
rewriter, op.getLoc(), op->getOperand(0), reshaped);
61036103

61046104
SmallVector<int64_t> bcast;
6105-
for (auto en : llvm::enumerate(reshapeVal.getType().getShape())) {
6105+
for (auto en : llvm::enumerate(cast<RankedTensorType>(reshapeVal.getType()).getShape())) {
61066106
bcast.push_back(en.index());
61076107
}
61086108
reshaped[op.getDimension()] = op->getNumOperands();
@@ -21652,8 +21652,6 @@ struct SumToConv : public SumToReductionBase<ST, SumToConv<ST>> {
2165221652
if (conv.getType() != pre_reshape) {
2165321653
SmallVector<int64_t> post_shape = llvm::to_vector(pre_reshape.getShape());
2165421654
post_shape[reshapeOffsetDim] -= (lastidx - startidx);
21655-
RankedTensorType post_reshape =
21656-
RankedTensorType::get(post_shape, pre_reshape.getElementType());
2165721655
conv = stablehlo::ReshapeOpCreate(rewriter, input.getLoc(),
2165821656
conv, post_shape);
2165921657
}
@@ -22794,7 +22792,7 @@ struct RecognizeWrap
2279422792
auto reshape = stablehlo::ReshapeOpCreate(
2279522793
rewriter, concat.getLoc(), wrap, newShape);
2279622794
if (auto shard = sdy::getShardingPerValue(rs0)) {
22797-
sdy::setShardings(reshape, shard);
22795+
sdy::setShardings(reshape.getDefiningOp(), shard);
2279822796
}
2279922797
toConcat.push_back(reshape);
2280022798
for (int j = i + 1; j < operands.size(); j++)
@@ -22806,7 +22804,7 @@ struct RecognizeWrap
2280622804
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
2280722805
concat, toConcat, concatDim);
2280822806
if (shard) {
22809-
sdy::setShardings(reshape, shard);
22807+
sdy::setShardings(reshape.getDefiningOp(), shard);
2281022808
}
2281122809
}
2281222810
return success();
@@ -23234,7 +23232,7 @@ struct RecognizeExtend
2323423232
auto reshape = stablehlo::ReshapeOpCreate(
2323523233
rewriter, concat.getLoc(), extend, shape);
2323623234
if (auto shard = sdy::getShardingPerValue(concat)) {
23237-
sdy::setShardings(reshape, shard);
23235+
sdy::setShardings(reshape.getDefiningOp(), shard);
2323823236
}
2323923237
finish(reshape);
2324023238
return success();
@@ -31316,11 +31314,6 @@ struct DotGeneralInsertDimContractionSimplification final
3131631314
}
3131731315

3131831316
// Create reshaped operands. This will be cleaned up later
31319-
auto newLhsType =
31320-
RankedTensorType::get(newLhsShape, lhsType.getElementType());
31321-
auto newRhsType =
31322-
RankedTensorType::get(newRhsShape, rhsType.getElementType());
31323-
3132431317
Value newLhs =
3132531318
stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), lhs, newLhsShape);
3132631319
Value newRhs =
@@ -31428,8 +31421,6 @@ struct DeleteDimsBroadcast final
3142831421
}
3142931422

3143031423
// Create the reshape on the input
31431-
auto newInputTy =
31432-
RankedTensorType::get(newInputShape, bcastInputTy.getElementType());
3143331424
auto reshapeInput = stablehlo::ReshapeOpCreate(
3143431425
rewriter, op.getLoc(), bcastOp.getOperand(), newInputShape);
3143531426

0 commit comments

Comments
 (0)