Skip to content

Commit a45d105

Browse files
Copilotavik-pal
andcommitted
Fix compilation errors in ReshapeOpCreate refactoring
- Remove .getResult() call on Value at line 1047 - Cast Value to RankedTensorType before getShape() at line 6105 - Remove unused variable post_reshape at line 21655 - Call getDefiningOp() on Value before setShardings at lines 22797, 22809, 23237 - Remove unused variables newLhsType, newRhsType at lines 31319-31322 - Remove unused variable newInputTy at line 31431 Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com>
1 parent cd91947 commit a45d105

File tree

1 file changed

+5
-14
lines changed

1 file changed

+5
-14
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ struct ReshapeExtend final
10971097

10981098
// Then create a new extend operation on the reshaped data
10991099
auto newExtendOp = enzymexla::ExtendOp::create(rewriter, op.getLoc(),
1100-
newReshapeOp.getResult(),
1100+
newReshapeOp,
11011101
lhs, rhs, newExtendDim);
11021102

11031103
// Replace the original reshape op with the new extend operation
@@ -6155,7 +6155,7 @@ struct ConcatToBroadcast final
61556155
rewriter, op.getLoc(), op->getOperand(0), reshaped);
61566156

61576157
SmallVector<int64_t> bcast;
6158-
for (auto en : llvm::enumerate(reshapeVal.getType().getShape())) {
6158+
for (auto en : llvm::enumerate(cast<RankedTensorType>(reshapeVal.getType()).getShape())) {
61596159
bcast.push_back(en.index());
61606160
}
61616161
reshaped[op.getDimension()] = op->getNumOperands();
@@ -21705,8 +21705,6 @@ struct SumToConv : public SumToReductionBase<ST, SumToConv<ST>> {
2170521705
if (conv.getType() != pre_reshape) {
2170621706
SmallVector<int64_t> post_shape = llvm::to_vector(pre_reshape.getShape());
2170721707
post_shape[reshapeOffsetDim] -= (lastidx - startidx);
21708-
RankedTensorType post_reshape =
21709-
RankedTensorType::get(post_shape, pre_reshape.getElementType());
2171021708
conv = stablehlo::ReshapeOpCreate(rewriter, input.getLoc(),
2171121709
conv, post_shape);
2171221710
}
@@ -22847,7 +22845,7 @@ struct RecognizeWrap
2284722845
auto reshape = stablehlo::ReshapeOpCreate(
2284822846
rewriter, concat.getLoc(), wrap, newShape);
2284922847
if (auto shard = sdy::getShardingPerValue(rs0)) {
22850-
sdy::setShardings(reshape, shard);
22848+
sdy::setShardings(reshape.getDefiningOp(), shard);
2285122849
}
2285222850
toConcat.push_back(reshape);
2285322851
for (int j = i + 1; j < operands.size(); j++)
@@ -22859,7 +22857,7 @@ struct RecognizeWrap
2285922857
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
2286022858
concat, toConcat, concatDim);
2286122859
if (shard) {
22862-
sdy::setShardings(reshape, shard);
22860+
sdy::setShardings(reshape.getDefiningOp(), shard);
2286322861
}
2286422862
}
2286522863
return success();
@@ -23287,7 +23285,7 @@ struct RecognizeExtend
2328723285
auto reshape = stablehlo::ReshapeOpCreate(
2328823286
rewriter, concat.getLoc(), extend, shape);
2328923287
if (auto shard = sdy::getShardingPerValue(concat)) {
23290-
sdy::setShardings(reshape, shard);
23288+
sdy::setShardings(reshape.getDefiningOp(), shard);
2329123289
}
2329223290
finish(reshape);
2329323291
return success();
@@ -31371,11 +31369,6 @@ struct DotGeneralInsertDimContractionSimplification final
3137131369
}
3137231370

3137331371
// Create reshaped operands. This will be cleaned up later
31374-
auto newLhsType =
31375-
RankedTensorType::get(newLhsShape, lhsType.getElementType());
31376-
auto newRhsType =
31377-
RankedTensorType::get(newRhsShape, rhsType.getElementType());
31378-
3137931372
Value newLhs =
3138031373
stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), lhs, newLhsShape);
3138131374
Value newRhs =
@@ -31483,8 +31476,6 @@ struct DeleteDimsBroadcast final
3148331476
}
3148431477

3148531478
// Create the reshape on the input
31486-
auto newInputTy =
31487-
RankedTensorType::get(newInputShape, bcastInputTy.getElementType());
3148831479
auto reshapeInput = stablehlo::ReshapeOpCreate(
3148931480
rewriter, op.getLoc(), bcastOp.getOperand(), newInputShape);
3149031481

0 commit comments

Comments
 (0)