Skip to content

Commit f132335

Browse files
Copilotwsmoses
andcommitted
Fix code formatting with clang-format
Run clang-format on Utils.cpp and Utils.h to fix formatting issues Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com>
1 parent 33cf3f6 commit f132335

File tree

2 files changed

+58
-46
lines changed

2 files changed

+58
-46
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,9 +1243,10 @@ bool mayReadFrom(Operation *op, Value val) {
12431243
return true;
12441244
}
12451245

1246-
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
1247-
ArrayRef<mlir::Type> inputTensorTypes,
1248-
ArrayRef<int64_t> byteOffsets) {
1246+
mlir::func::FuncOp
1247+
adaptToCallingConvention(mlir::func::FuncOp f,
1248+
ArrayRef<mlir::Type> inputTensorTypes,
1249+
ArrayRef<int64_t> byteOffsets) {
12491250
// Get the original function type
12501251
auto originalFuncType = f.getFunctionType();
12511252
size_t numInputs = originalFuncType.getNumInputs();
@@ -1259,69 +1260,70 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
12591260
// Create the new function type using the outer specification types
12601261
auto context = f.getContext();
12611262
auto loc = f.getLoc();
1262-
auto newFuncType = mlir::FunctionType::get(
1263-
context, inputTensorTypes, originalFuncType.getResults());
1264-
1263+
auto newFuncType = mlir::FunctionType::get(context, inputTensorTypes,
1264+
originalFuncType.getResults());
1265+
12651266
// Create a new function with a unique name
12661267
std::string wrapperName = (f.getName() + "_adapted").str();
12671268
OpBuilder builder(context);
12681269
builder.setInsertionPoint(f);
1269-
1270-
auto wrapperFunc = builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);
1271-
1270+
1271+
auto wrapperFunc =
1272+
builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);
1273+
12721274
// Add entry block to the wrapper function
12731275
auto &entryBlock = *wrapperFunc.addEntryBlock();
12741276
builder.setInsertionPointToStart(&entryBlock);
1275-
1277+
12761278
// Process each argument
12771279
SmallVector<Value> adaptedArgs;
12781280
for (size_t i = 0; i < numInputs; ++i) {
12791281
Value arg = entryBlock.getArgument(i);
12801282
auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
12811283
auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));
1282-
1284+
12831285
if (!outerType || !innerType) {
12841286
// If not tensor types, pass through as-is
12851287
adaptedArgs.push_back(arg);
12861288
continue;
12871289
}
1288-
1290+
12891291
Value adaptedArg = arg;
1290-
1292+
12911293
// Handle byte offset if non-zero
12921294
int64_t byteOffset = byteOffsets[i];
12931295
if (byteOffset != 0) {
12941296
// Calculate element offset from byte offset
12951297
auto elementType = outerType.getElementType();
12961298

12971299
// Get element size in bytes using AutoDiffTypeInterface
1298-
size_t elementSizeBytes =
1300+
size_t elementSizeBytes =
12991301
cast<AutoDiffTypeInterface>(elementType).getApproxSize();
13001302

13011303
// Verify byte offset aligns with element boundaries
13021304
assert(byteOffset % elementSizeBytes == 0 &&
13031305
"Byte offset must be aligned to element boundaries");
13041306

13051307
int64_t elementOffset = byteOffset / elementSizeBytes;
1306-
1308+
13071309
auto outerShape = outerType.getShape();
13081310
auto innerShape = innerType.getShape();
1309-
1311+
13101312
// Convert linear element offset to multi-dimensional start indices
13111313
SmallVector<int64_t> startIndices;
13121314
SmallVector<int64_t> limitIndices;
13131315
SmallVector<int64_t> strides(outerShape.size(), 1);
1314-
1316+
13151317
int64_t remainingOffset = elementOffset;
1316-
1318+
13171319
// Calculate strides for each dimension (row-major order)
13181320
for (size_t j = 0; j < outerShape.size(); ++j) {
13191321
// Calculate the stride for this dimension
13201322
int64_t dimStride = 1;
13211323
for (size_t k = j + 1; k < outerShape.size(); ++k) {
13221324
dimStride *= outerShape[k];
13231325
}
1324-
1326+
13251327
// Calculate the index for this dimension
13261328
int64_t dimIndex = remainingOffset / dimStride;
13271329
startIndices.push_back(dimIndex);
@@ -1338,25 +1340,26 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
13381340
// Update remaining offset for next dimension
13391341
remainingOffset = remainingOffset % dimStride;
13401342
}
1341-
1342-
auto slicedType = RankedTensorType::get(innerShape, outerType.getElementType());
1343+
1344+
auto slicedType =
1345+
RankedTensorType::get(innerShape, outerType.getElementType());
13431346
adaptedArg = builder.create<stablehlo::SliceOp>(
13441347
loc, slicedType, adaptedArg,
13451348
builder.getDenseI64ArrayAttr(startIndices),
13461349
builder.getDenseI64ArrayAttr(limitIndices),
13471350
builder.getDenseI64ArrayAttr(strides));
13481351
}
1349-
1352+
13501353
// Handle element type conversion if needed using bitcast_convert
13511354
auto currentType = cast<RankedTensorType>(adaptedArg.getType());
13521355
if (currentType.getElementType() != innerType.getElementType()) {
13531356
auto currentElemType = currentType.getElementType();
13541357
auto targetElemType = innerType.getElementType();
13551358

13561359
// Calculate element sizes in bytes using AutoDiffTypeInterface
1357-
size_t currentSizeBytes =
1360+
size_t currentSizeBytes =
13581361
cast<AutoDiffTypeInterface>(currentElemType).getApproxSize();
1359-
size_t targetSizeBytes =
1362+
size_t targetSizeBytes =
13601363
cast<AutoDiffTypeInterface>(targetElemType).getApproxSize();
13611364

13621365
assert(currentSizeBytes > 0 && targetSizeBytes > 0 &&
@@ -1365,14 +1368,15 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
13651368
Value res;
13661369
auto currentShape = currentType.getShape();
13671370
auto targetShape = innerType.getShape();
1368-
1371+
13691372
// Scalar i32 tensor type for shape constants
13701373
auto scalarI32Type = RankedTensorType::get({}, builder.getI32Type());
13711374

13721375
if (currentSizeBytes == targetSizeBytes) {
13731376
// Same size: direct bitcast
13741377
auto convertedType = RankedTensorType::get(targetShape, targetElemType);
1375-
res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType, adaptedArg);
1378+
res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType,
1379+
adaptedArg);
13761380
} else if (targetSizeBytes < currentSizeBytes) {
13771381
// Target element is smaller: add dimension at the end
13781382
assert(currentSizeBytes % targetSizeBytes == 0 &&
@@ -1384,24 +1388,29 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
13841388
intermediateShape.push_back(sizeRatio);
13851389

13861390
// Adjust the last dimension if needed
1387-
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
1391+
if (lastIdx > 0 &&
1392+
intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
13881393
intermediateShape[lastIdx - 1] /= sizeRatio;
13891394
}
13901395

1391-
auto intermediateType = RankedTensorType::get(intermediateShape, targetElemType);
1392-
res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);
1396+
auto intermediateType =
1397+
RankedTensorType::get(intermediateShape, targetElemType);
1398+
res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType,
1399+
adaptedArg);
13931400

1394-
// Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
1401+
// Always use dynamic reshape with GetDimensionSizeOp (will be optimized
1402+
// away for static shapes)
13951403
SmallVector<Value> shapeValues;
13961404
for (size_t i = 0; i < targetShape.size(); ++i) {
1397-
auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
1398-
loc, res, i);
1405+
auto dimValue =
1406+
builder.create<stablehlo::GetDimensionSizeOp>(loc, res, i);
13991407
shapeValues.push_back(dimValue);
14001408
}
1401-
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
1402-
loc, shapeValues, 0);
1409+
auto shapeOp =
1410+
builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
14031411
res = builder.create<stablehlo::DynamicReshapeOp>(
1404-
loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
1412+
loc, RankedTensorType::get(targetShape, targetElemType), res,
1413+
shapeOp);
14051414
} else {
14061415
// Target element is larger: reshape first, then bitcast
14071416
assert(targetSizeBytes % currentSizeBytes == 0 &&
@@ -1413,11 +1422,13 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
14131422
intermediateShape.push_back(sizeRatio);
14141423

14151424
// Adjust the last dimension if needed
1416-
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
1425+
if (lastIdx > 0 &&
1426+
intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
14171427
intermediateShape[lastIdx - 1] /= sizeRatio;
14181428
}
14191429

1420-
// Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
1430+
// Always use dynamic reshape with GetDimensionSizeOp (will be optimized
1431+
// away for static shapes)
14211432
SmallVector<Value> shapeValues;
14221433
for (size_t i = 0; i < intermediateShape.size(); ++i) {
14231434
if (i < currentShape.size()) {
@@ -1431,8 +1442,8 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
14311442
shapeValues.push_back(constValue);
14321443
}
14331444
}
1434-
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
1435-
loc, shapeValues, 0);
1445+
auto shapeOp =
1446+
builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
14361447
Value reshaped = builder.create<stablehlo::DynamicReshapeOp>(
14371448
loc, RankedTensorType::get(intermediateShape, currentElemType),
14381449
adaptedArg, shapeOp);
@@ -1444,16 +1455,16 @@ mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
14441455

14451456
adaptedArg = res;
14461457
}
1447-
1458+
14481459
adaptedArgs.push_back(adaptedArg);
14491460
}
1450-
1461+
14511462
// Call the original function with adapted arguments
14521463
auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);
1453-
1464+
14541465
// Return the results
14551466
builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());
1456-
1467+
14571468
return wrapperFunc;
14581469
}
14591470

src/enzyme_ad/jax/Utils.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,10 @@ bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);
848848
/// \param inputTensorTypes The tensor types for the wrapper function arguments
849849
/// \param byteOffsets Byte offsets for each argument (0 means no offset)
850850
/// \return A new function that adapts the calling convention
851-
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
852-
ArrayRef<mlir::Type> inputTensorTypes,
853-
ArrayRef<int64_t> byteOffsets);
851+
mlir::func::FuncOp
852+
adaptToCallingConvention(mlir::func::FuncOp f,
853+
ArrayRef<mlir::Type> inputTensorTypes,
854+
ArrayRef<int64_t> byteOffsets);
854855

855856
} // namespace enzyme
856857

0 commit comments

Comments
 (0)