@@ -1391,40 +1391,17 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
13911391 auto intermediateType = RankedTensorType::get (intermediateShape, targetElemType);
13921392 res = builder.create <stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);
13931393
1394- // Check if we need dynamic or static reshape
1395- bool anyDynamic = false ;
1396- for (auto dim : intermediateShape) {
1397- if (dim == ShapedType::kDynamic ) {
1398- anyDynamic = true ;
1399- break ;
1400- }
1401- }
1402-
1403- if (anyDynamic) {
1404- // Use dynamic reshape
1405- SmallVector<Value> shapeValues;
1406- for (size_t i = 0 ; i < targetShape.size (); ++i) {
1407- if (targetShape[i] == ShapedType::kDynamic ) {
1408- // Get dynamic dimension from original tensor
1409- auto dimValue = builder.create <stablehlo::GetDimensionSizeOp>(
1410- loc, scalarI32Type, adaptedArg, i);
1411- shapeValues.push_back (dimValue);
1412- } else {
1413- auto constValue = builder.create <stablehlo::ConstantOp>(
1414- loc, scalarI32Type,
1415- cast<ElementsAttr>(makeAttr (scalarI32Type, targetShape[i])));
1416- shapeValues.push_back (constValue);
1417- }
1418- }
1419- auto shapeOp = builder.create <stablehlo::ConcatenateOp>(
1420- loc, shapeValues, 0 );
1421- res = builder.create <stablehlo::DynamicReshapeOp>(
1422- loc, RankedTensorType::get (targetShape, targetElemType), res, shapeOp);
1423- } else {
1424- // Use static reshape
1425- res = builder.create <stablehlo::ReshapeOp>(
1426- loc, RankedTensorType::get (targetShape, targetElemType), res);
1394+ // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
1395+ SmallVector<Value> shapeValues;
1396+ for (size_t i = 0 ; i < targetShape.size (); ++i) {
1397+ auto dimValue = builder.create <stablehlo::GetDimensionSizeOp>(
1398+ loc, res, i);
1399+ shapeValues.push_back (dimValue);
14271400 }
1401+ auto shapeOp = builder.create <stablehlo::ConcatenateOp>(
1402+ loc, shapeValues, 0 );
1403+ res = builder.create <stablehlo::DynamicReshapeOp>(
1404+ loc, RankedTensorType::get (targetShape, targetElemType), res, shapeOp);
14281405 } else {
14291406 // Target element is larger: reshape first, then bitcast
14301407 assert (targetSizeBytes % currentSizeBytes == 0 &&
@@ -1440,48 +1417,25 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
14401417 intermediateShape[lastIdx - 1 ] /= sizeRatio;
14411418 }
14421419
1443- Value reshaped;
1444- // Check if we need dynamic reshape
1445- bool anyDynamic = false ;
1446- for (auto dim : intermediateShape) {
1447- if (dim == ShapedType::kDynamic ) {
1448- anyDynamic = true ;
1449- break ;
1450- }
1451- }
1452-
1453- if (anyDynamic) {
1454- // Use dynamic reshape
1455- SmallVector<Value> shapeValues;
1456- for (size_t i = 0 ; i < intermediateShape.size (); ++i) {
1457- if (intermediateShape[i] == ShapedType::kDynamic ) {
1458- if (i < currentShape.size ()) {
1459- auto dimValue = builder.create <stablehlo::GetDimensionSizeOp>(
1460- loc, scalarI32Type, adaptedArg, i);
1461- shapeValues.push_back (dimValue);
1462- } else {
1463- // This is the added dimension
1464- auto constValue = builder.create <stablehlo::ConstantOp>(
1465- loc, scalarI32Type,
1466- cast<ElementsAttr>(makeAttr (scalarI32Type, sizeRatio)));
1467- shapeValues.push_back (constValue);
1468- }
1469- } else {
1470- auto constValue = builder.create <stablehlo::ConstantOp>(
1471- loc, scalarI32Type,
1472- cast<ElementsAttr>(makeAttr (scalarI32Type, intermediateShape[i])));
1473- shapeValues.push_back (constValue);
1474- }
1420+ // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
1421+ SmallVector<Value> shapeValues;
1422+ for (size_t i = 0 ; i < intermediateShape.size (); ++i) {
1423+ if (i < currentShape.size ()) {
1424+ auto dimValue = builder.create <stablehlo::GetDimensionSizeOp>(
1425+ loc, adaptedArg, i);
1426+ shapeValues.push_back (dimValue);
1427+ } else {
1428+ // This is the added dimension
1429+ auto constValue = builder.create <stablehlo::ConstantOp>(
1430+ loc, cast<ElementsAttr>(makeAttr (scalarI32Type, sizeRatio)));
1431+ shapeValues.push_back (constValue);
14751432 }
1476- auto shapeOp = builder.create <stablehlo::ConcatenateOp>(
1477- loc, shapeValues, 0 );
1478- reshaped = builder.create <stablehlo::DynamicReshapeOp>(
1479- loc, RankedTensorType::get (intermediateShape, currentElemType),
1480- adaptedArg, shapeOp);
1481- } else {
1482- reshaped = builder.create <stablehlo::ReshapeOp>(
1483- loc, RankedTensorType::get (intermediateShape, currentElemType), adaptedArg);
14841433 }
1434+ auto shapeOp = builder.create <stablehlo::ConcatenateOp>(
1435+ loc, shapeValues, 0 );
1436+ Value reshaped = builder.create <stablehlo::DynamicReshapeOp>(
1437+ loc, RankedTensorType::get (intermediateShape, currentElemType),
1438+ adaptedArg, shapeOp);
14851439
14861440 // Now bitcast to target type
14871441 res = builder.create <stablehlo::BitcastConvertOp>(
0 commit comments