@@ -1243,9 +1243,10 @@ bool mayReadFrom(Operation *op, Value val) {
   return true;
 }
 
-mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
-                                            ArrayRef<mlir::Type> inputTensorTypes,
-                                            ArrayRef<int64_t> byteOffsets) {
+mlir::func::FuncOp
+adaptToCallingConvention(mlir::func::FuncOp f,
+                         ArrayRef<mlir::Type> inputTensorTypes,
+                         ArrayRef<int64_t> byteOffsets) {
   // Get the original function type
   auto originalFuncType = f.getFunctionType();
   size_t numInputs = originalFuncType.getNumInputs();
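
For orientation (not part of the commit): the reworked signature takes the outer tensor types and per-argument byte offsets alongside the inner function. A hypothetical call site, with `innerFunc`, the `builder`, and all values invented for illustration, might look like this:

```cpp
// Hypothetical usage sketch; names and values are assumptions, not from
// this commit. The outer caller supplies a 64-element i8 view whose
// payload starts 4 bytes in; the wrapper slices and bitcasts it back to
// the inner function's expected tensor type.
SmallVector<mlir::Type> outerTypes = {
    RankedTensorType::get({64}, builder.getI8Type())};
SmallVector<int64_t> byteOffsets = {4};
mlir::func::FuncOp wrapper =
    adaptToCallingConvention(innerFunc, outerTypes, byteOffsets);
```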
@@ -1259,69 +1260,70 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
   // Create the new function type using the outer specification types
   auto context = f.getContext();
   auto loc = f.getLoc();
-  auto newFuncType = mlir::FunctionType::get(
-      context, inputTensorTypes, originalFuncType.getResults());
-
+  auto newFuncType = mlir::FunctionType::get(context, inputTensorTypes,
+                                             originalFuncType.getResults());
+
   // Create a new function with a unique name
   std::string wrapperName = (f.getName() + "_adapted").str();
   OpBuilder builder(context);
   builder.setInsertionPoint(f);
-
-  auto wrapperFunc = builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);
-
+
+  auto wrapperFunc =
+      builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);
+
   // Add entry block to the wrapper function
   auto &entryBlock = *wrapperFunc.addEntryBlock();
   builder.setInsertionPointToStart(&entryBlock);
-
+
   // Process each argument
   SmallVector<Value> adaptedArgs;
   for (size_t i = 0; i < numInputs; ++i) {
     Value arg = entryBlock.getArgument(i);
     auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
     auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));
-
+
     if (!outerType || !innerType) {
       // If not tensor types, pass through as-is
       adaptedArgs.push_back(arg);
       continue;
     }
-
+
     Value adaptedArg = arg;
-
+
     // Handle byte offset if non-zero
     int64_t byteOffset = byteOffsets[i];
     if (byteOffset != 0) {
       // Calculate element offset from byte offset
       auto elementType = outerType.getElementType();
 
       // Get element size in bytes using AutoDiffTypeInterface
-      size_t elementSizeBytes =
+      size_t elementSizeBytes =
           cast<AutoDiffTypeInterface>(elementType).getApproxSize();
 
       // Verify byte offset aligns with element boundaries
       assert(byteOffset % elementSizeBytes == 0 &&
              "Byte offset must be aligned to element boundaries");
 
       int64_t elementOffset = byteOffset / elementSizeBytes;
-
+
       auto outerShape = outerType.getShape();
       auto innerShape = innerType.getShape();
-
+
       // Convert linear element offset to multi-dimensional start indices
       SmallVector<int64_t> startIndices;
       SmallVector<int64_t> limitIndices;
       SmallVector<int64_t> strides(outerShape.size(), 1);
-
+
       int64_t remainingOffset = elementOffset;
-
+
       // Calculate strides for each dimension (row-major order)
       for (size_t j = 0; j < outerShape.size(); ++j) {
         // Calculate the stride for this dimension
         int64_t dimStride = 1;
         for (size_t k = j + 1; k < outerShape.size(); ++k) {
           dimStride *= outerShape[k];
         }
-
+
         // Calculate the index for this dimension
         int64_t dimIndex = remainingOffset / dimStride;
         startIndices.push_back(dimIndex);
@@ -1338,25 +1340,26 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
         // Update remaining offset for next dimension
         remainingOffset = remainingOffset % dimStride;
       }
-
-      auto slicedType = RankedTensorType::get(innerShape, outerType.getElementType());
+
+      auto slicedType =
+          RankedTensorType::get(innerShape, outerType.getElementType());
       adaptedArg = builder.create<stablehlo::SliceOp>(
           loc, slicedType, adaptedArg,
           builder.getDenseI64ArrayAttr(startIndices),
           builder.getDenseI64ArrayAttr(limitIndices),
           builder.getDenseI64ArrayAttr(strides));
     }
-
+
     // Handle element type conversion if needed using bitcast_convert
     auto currentType = cast<RankedTensorType>(adaptedArg.getType());
     if (currentType.getElementType() != innerType.getElementType()) {
       auto currentElemType = currentType.getElementType();
       auto targetElemType = innerType.getElementType();
 
       // Calculate element sizes in bytes using AutoDiffTypeInterface
-      size_t currentSizeBytes =
+      size_t currentSizeBytes =
           cast<AutoDiffTypeInterface>(currentElemType).getApproxSize();
-      size_t targetSizeBytes =
+      size_t targetSizeBytes =
           cast<AutoDiffTypeInterface>(targetElemType).getApproxSize();
 
       assert(currentSizeBytes > 0 && targetSizeBytes > 0 &&
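
The start-index computation in the two hunks above is plain row-major delinearization. A minimal standalone sketch of the same arithmetic, free of MLIR dependencies and added here only for reference:

```cpp
#include <cstdint>
#include <vector>

// Row-major delinearization: convert a linear element offset into
// per-dimension start indices for a tensor of the given shape.
// Mirrors the loop in the hunk above; standalone for illustration.
std::vector<int64_t> delinearize(int64_t offset,
                                 const std::vector<int64_t> &shape) {
  std::vector<int64_t> indices;
  int64_t remaining = offset;
  for (size_t j = 0; j < shape.size(); ++j) {
    // Stride of dimension j is the product of all trailing extents.
    int64_t stride = 1;
    for (size_t k = j + 1; k < shape.size(); ++k)
      stride *= shape[k];
    indices.push_back(remaining / stride);
    remaining %= stride;
  }
  return indices;
}

// Example: delinearize(5, {2, 4}) yields {1, 1}: the stride of dim 0 is
// 4, so 5 / 4 = 1 with remainder 1, and then 1 / 1 = 1.
```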
@@ -1365,14 +1368,15 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
       Value res;
       auto currentShape = currentType.getShape();
       auto targetShape = innerType.getShape();
-
+
       // Scalar i32 tensor type for shape constants
       auto scalarI32Type = RankedTensorType::get({}, builder.getI32Type());
 
       if (currentSizeBytes == targetSizeBytes) {
         // Same size: direct bitcast
         auto convertedType = RankedTensorType::get(targetShape, targetElemType);
-        res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType, adaptedArg);
+        res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType,
+                                                          adaptedArg);
       } else if (targetSizeBytes < currentSizeBytes) {
         // Target element is smaller: add dimension at the end
         assert(currentSizeBytes % targetSizeBytes == 0 &&
@@ -1384,24 +1388,29 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
         intermediateShape.push_back(sizeRatio);
 
         // Adjust the last dimension if needed
-        if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
+        if (lastIdx > 0 &&
+            intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
           intermediateShape[lastIdx - 1] /= sizeRatio;
         }
 
-        auto intermediateType = RankedTensorType::get(intermediateShape, targetElemType);
-        res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);
+        auto intermediateType =
+            RankedTensorType::get(intermediateShape, targetElemType);
+        res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType,
+                                                          adaptedArg);
 
-        // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
+        // Always use dynamic reshape with GetDimensionSizeOp (will be optimized
+        // away for static shapes)
         SmallVector<Value> shapeValues;
         for (size_t i = 0; i < targetShape.size(); ++i) {
-          auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
-              loc, res, i);
+          auto dimValue =
+              builder.create<stablehlo::GetDimensionSizeOp>(loc, res, i);
           shapeValues.push_back(dimValue);
         }
-        auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
-            loc, shapeValues, 0);
+        auto shapeOp =
+            builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
         res = builder.create<stablehlo::DynamicReshapeOp>(
-            loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
+            loc, RankedTensorType::get(targetShape, targetElemType), res,
+            shapeOp);
       } else {
         // Target element is larger: reshape first, then bitcast
         assert(targetSizeBytes % currentSizeBytes == 0 &&
@@ -1413,11 +1422,13 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
         intermediateShape.push_back(sizeRatio);
 
         // Adjust the last dimension if needed
-        if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
+        if (lastIdx > 0 &&
+            intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
           intermediateShape[lastIdx - 1] /= sizeRatio;
         }
 
-        // Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
+        // Always use dynamic reshape with GetDimensionSizeOp (will be optimized
+        // away for static shapes)
         SmallVector<Value> shapeValues;
         for (size_t i = 0; i < intermediateShape.size(); ++i) {
           if (i < currentShape.size()) {
@@ -1431,8 +1442,8 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
             shapeValues.push_back(constValue);
           }
         }
-        auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
-            loc, shapeValues, 0);
+        auto shapeOp =
+            builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
         Value reshaped = builder.create<stablehlo::DynamicReshapeOp>(
             loc, RankedTensorType::get(intermediateShape, currentElemType),
             adaptedArg, shapeOp);
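
All three conversion branches lean on stablehlo's `bitcast_convert` shape rule: equal element widths keep the shape, a narrower target element appends a minor-most dimension equal to the width ratio, and a wider target element consumes that dimension. A standalone sketch of that shape arithmetic (an illustration under those semantics, not the commit's code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Shape arithmetic behind the three bitcast_convert branches above.
// Standalone illustration; fromBytes/toBytes are element widths in bytes.
std::vector<int64_t> bitcastShape(std::vector<int64_t> shape,
                                  int64_t fromBytes, int64_t toBytes) {
  if (fromBytes == toBytes)
    return shape; // same width: bits reinterpreted, shape untouched
  if (toBytes < fromBytes) {
    assert(fromBytes % toBytes == 0 && "element widths must divide evenly");
    shape.push_back(fromBytes / toBytes); // e.g. f32 [8] -> f16 [8, 2]
  } else {
    assert(toBytes % fromBytes == 0 && "element widths must divide evenly");
    assert(!shape.empty() && shape.back() == toBytes / fromBytes &&
           "minor-most dimension must equal the width ratio");
    shape.pop_back(); // e.g. f16 [8, 2] -> f32 [8]
  }
  return shape;
}
```

The dynamic reshape the diff emits after the bitcast (narrowing) or before it (widening) is what reconciles this appended or consumed dimension with the inner function's declared shape.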
@@ -1444,16 +1455,16 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
 
       adaptedArg = res;
     }
-
+
     adaptedArgs.push_back(adaptedArg);
   }
-
+
   // Call the original function with adapted arguments
   auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);
-
+
   // Return the results
   builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());
-
+
   return wrapperFunc;
 }
 
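Tying the offset handling together, a runnable trace with invented values (it reuses the `delinearize` sketch from earlier; nothing here comes from the commit):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> delinearize(int64_t offset,
                                 const std::vector<int64_t> &shape); // above

int main() {
  const int64_t byteOffset = 16;      // assumed caller-provided offset
  const int64_t elementSizeBytes = 4; // f32
  const std::vector<int64_t> outerShape = {2, 4};

  // 16 bytes / 4 bytes per element = element offset 4, matching the
  // byteOffset / elementSizeBytes step in the wrapper.
  const int64_t elementOffset = byteOffset / elementSizeBytes;

  // Row-major delinearization of 4 in a 2x4 tensor -> start indices {1, 0}.
  for (int64_t idx : delinearize(elementOffset, outerShape))
    std::cout << idx << ' '; // prints: 1 0
  std::cout << '\n';
  return 0;
}
```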