
Commit 33cf3f6

Copilot and wsmoses committed
Simplify reshape logic using GetDimensionSizeOp uniformly
- Remove conditional branches for static vs dynamic reshape
- Always use GetDimensionSizeOp for all dimensions (optimized away for static shapes)
- Remove explicit type parameter from ConstantOp::create (type deduced automatically)
- Simplifies code and relies on compiler optimizations

Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
1 parent 6af39ce commit 33cf3f6
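
For context on the second bullet, the uniform pattern the commit adopts reads roughly as below. This is a minimal sketch distilled from the first hunk, not standalone code: builder, loc, res, targetShape, and targetElemType are all defined by the surrounding function in Utils.cpp, and the "optimized away" claim relies on StableHLO canonicalization folding GetDimensionSizeOp of a statically known dimension into a constant.

    // Query every output dimension at run time, static or not. The
    // canonicalizer folds GetDimensionSizeOp of a static dimension to a
    // constant, after which the DynamicReshapeOp has a constant shape
    // operand and simplifies to an ordinary static reshape.
    SmallVector<Value> shapeValues;
    for (size_t i = 0; i < targetShape.size(); ++i)
      shapeValues.push_back(
          builder.create<stablehlo::GetDimensionSizeOp>(loc, res, i));
    auto shapeOp = builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
    res = builder.create<stablehlo::DynamicReshapeOp>(
        loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);

The trade-off: a few extra ops in the unoptimized IR in exchange for a single code path instead of two.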

1 file changed: +27 −73 lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 27 additions & 73 deletions
@@ -1391,40 +1391,17 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
       auto intermediateType = RankedTensorType::get(intermediateShape, targetElemType);
       res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);

-      // Check if we need dynamic or static reshape
-      bool anyDynamic = false;
-      for (auto dim : intermediateShape) {
-        if (dim == ShapedType::kDynamic) {
-          anyDynamic = true;
-          break;
-        }
-      }
-
-      if (anyDynamic) {
-        // Use dynamic reshape
-        SmallVector<Value> shapeValues;
-        for (size_t i = 0; i < targetShape.size(); ++i) {
-          if (targetShape[i] == ShapedType::kDynamic) {
-            // Get dynamic dimension from original tensor
-            auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
-                loc, scalarI32Type, adaptedArg, i);
-            shapeValues.push_back(dimValue);
-          } else {
-            auto constValue = builder.create<stablehlo::ConstantOp>(
-                loc, scalarI32Type,
-                cast<ElementsAttr>(makeAttr(scalarI32Type, targetShape[i])));
-            shapeValues.push_back(constValue);
-          }
-        }
-        auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
-            loc, shapeValues, 0);
-        res = builder.create<stablehlo::DynamicReshapeOp>(
-            loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
-      } else {
-        // Use static reshape
-        res = builder.create<stablehlo::ReshapeOp>(
-            loc, RankedTensorType::get(targetShape, targetElemType), res);
+      // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
+      SmallVector<Value> shapeValues;
+      for (size_t i = 0; i < targetShape.size(); ++i) {
+        auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
+            loc, res, i);
+        shapeValues.push_back(dimValue);
       }
+      auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
+          loc, shapeValues, 0);
+      res = builder.create<stablehlo::DynamicReshapeOp>(
+          loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
     } else {
       // Target element is larger: reshape first, then bitcast
       assert(targetSizeBytes % currentSizeBytes == 0 &&
@@ -1440,48 +1417,25 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
         intermediateShape[lastIdx - 1] /= sizeRatio;
       }

-      Value reshaped;
-      // Check if we need dynamic reshape
-      bool anyDynamic = false;
-      for (auto dim : intermediateShape) {
-        if (dim == ShapedType::kDynamic) {
-          anyDynamic = true;
-          break;
-        }
-      }
-
-      if (anyDynamic) {
-        // Use dynamic reshape
-        SmallVector<Value> shapeValues;
-        for (size_t i = 0; i < intermediateShape.size(); ++i) {
-          if (intermediateShape[i] == ShapedType::kDynamic) {
-            if (i < currentShape.size()) {
-              auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
-                  loc, scalarI32Type, adaptedArg, i);
-              shapeValues.push_back(dimValue);
-            } else {
-              // This is the added dimension
-              auto constValue = builder.create<stablehlo::ConstantOp>(
-                  loc, scalarI32Type,
-                  cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));
-              shapeValues.push_back(constValue);
-            }
-          } else {
-            auto constValue = builder.create<stablehlo::ConstantOp>(
-                loc, scalarI32Type,
-                cast<ElementsAttr>(makeAttr(scalarI32Type, intermediateShape[i])));
-            shapeValues.push_back(constValue);
-          }
+      // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
+      SmallVector<Value> shapeValues;
+      for (size_t i = 0; i < intermediateShape.size(); ++i) {
+        if (i < currentShape.size()) {
+          auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
+              loc, adaptedArg, i);
+          shapeValues.push_back(dimValue);
+        } else {
+          // This is the added dimension
+          auto constValue = builder.create<stablehlo::ConstantOp>(
+              loc, cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));
+          shapeValues.push_back(constValue);
         }
-        auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
-            loc, shapeValues, 0);
-        reshaped = builder.create<stablehlo::DynamicReshapeOp>(
-            loc, RankedTensorType::get(intermediateShape, currentElemType),
-            adaptedArg, shapeOp);
-      } else {
-        reshaped = builder.create<stablehlo::ReshapeOp>(
-            loc, RankedTensorType::get(intermediateShape, currentElemType), adaptedArg);
       }
+      auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
+          loc, shapeValues, 0);
+      Value reshaped = builder.create<stablehlo::DynamicReshapeOp>(
+          loc, RankedTensorType::get(intermediateShape, currentElemType),
+          adaptedArg, shapeOp);

       // Now bitcast to target type
       res = builder.create<stablehlo::BitcastConvertOp>(
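
On the commit message's third bullet: both spellings appear in the hunks above, so the change is easiest to see side by side. A sketch reusing the function's own scalarI32Type and makeAttr helpers:

    // Before: the scalar i32 result type is passed explicitly in
    // addition to the value attribute.
    auto before = builder.create<stablehlo::ConstantOp>(
        loc, scalarI32Type,
        cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));

    // After: stablehlo::ConstantOp also has a builder that deduces the
    // result type from the ElementsAttr itself, so the explicit type
    // argument is dropped.
    auto after = builder.create<stablehlo::ConstantOp>(
        loc, cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));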
