Skip to content

Commit 4b61728

Browse files
committed
feat: mark memory effects
1 parent af0d25d commit 4b61728

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ struct LUFactorizationOpLowering
159159
<< inputElementType;
160160
return rewriter.notifyMatchFailure(op, "unsupported input element type");
161161
}
162+
std::string lapackFnWrapper = lapackFn + "wrapper";
162163

163164
// Insert function declaration if not already present
164165
if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(lapackFn)) {
@@ -175,6 +176,49 @@ struct LUFactorizationOpLowering
175176
LLVM::Linkage::External);
176177
}
177178

179+
if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(lapackFnWrapper)) {
180+
OpBuilder::InsertionGuard guard(rewriter);
181+
rewriter.setInsertionPointToStart(moduleOp.getBody());
182+
183+
auto funcType =
184+
LLVM::LLVMFunctionType::get(llvmVoidPtrType,
185+
{llvmPtrType, llvmPtrType, llvmPtrType,
186+
llvmPtrType, llvmPtrType, llvmPtrType},
187+
false);
188+
189+
auto funcOp = rewriter.create<LLVM::LLVMFuncOp>(
190+
op.getLoc(), lapackFnWrapper, funcType, LLVM::Linkage::Private);
191+
rewriter.setInsertionPointToStart(funcOp.addEntryBlock(rewriter));
192+
193+
funcOp.setArgAttr(0, LLVM::LLVMDialect::getReadonlyAttrName(),
194+
rewriter.getUnitAttr());
195+
funcOp.setArgAttr(1, LLVM::LLVMDialect::getReadonlyAttrName(),
196+
rewriter.getUnitAttr());
197+
// Argument 2 is both read and written, so it gets neither readonly nor writeonly.
198+
funcOp.setArgAttr(3, LLVM::LLVMDialect::getReadonlyAttrName(),
199+
rewriter.getUnitAttr());
200+
funcOp.setArgAttr(4, LLVM::LLVMDialect::getWriteOnlyAttrName(),
201+
rewriter.getUnitAttr());
202+
funcOp.setArgAttr(5, LLVM::LLVMDialect::getWriteOnlyAttrName(),
203+
rewriter.getUnitAttr());
204+
for (int i = 0; i < 6; i++) {
205+
funcOp.setArgAttr(i, LLVM::LLVMDialect::getNoFreeAttrName(),
206+
rewriter.getUnitAttr());
207+
}
208+
209+
auto callOp = rewriter.create<LLVM::CallOp>(
210+
op.getLoc(), TypeRange{}, SymbolRefAttr::get(ctx, lapackFn),
211+
ValueRange{
212+
funcOp.getArgument(0),
213+
funcOp.getArgument(1),
214+
funcOp.getArgument(2),
215+
funcOp.getArgument(3),
216+
funcOp.getArgument(4),
217+
funcOp.getArgument(5),
218+
});
219+
rewriter.create<LLVM::ReturnOp>(op.getLoc(), ValueRange{});
220+
}
221+
178222
// Call the LLVM function with enzymexla.jit_call
179223
SmallVector<Attribute> aliases;
180224
aliases.push_back(stablehlo::OutputOperandAliasAttr::get(
@@ -205,7 +249,7 @@ struct LUFactorizationOpLowering
205249
std::string wrapperFnName = lapackFn + std::to_string(fnNum++);
206250

207251
func::FuncOp func = createWrapperFuncOpCPULapack(
208-
rewriter, lapackFn, unbatchedInputType, unbatchedBLASPivotType,
252+
rewriter, lapackFnWrapper, unbatchedInputType, unbatchedBLASPivotType,
209253
unbatchedBLASInfoType, blasIntType, wrapperFnName, op, operandLayouts,
210254
resultLayouts, rewriter.getArrayAttr(aliases));
211255
if (!func)

test/lit_tests/linalg/lu.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@ module {
1313
// CPU-NEXT: %c = stablehlo.constant dense<-1> : tensor<i64>
1414
// CPU-NEXT: %c_0 = stablehlo.constant dense<-1> : tensor<64xi64>
1515
// CPU-NEXT: %c_1 = stablehlo.constant dense<64> : tensor<i64>
16-
// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
16+
// CPU-NEXT: %0:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %arg0, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
1717
// CPU-NEXT: stablehlo.return %0#0, %0#1, %0#2 : tensor<64x64xf32>, tensor<64xi64>, tensor<i64>
1818
// CPU-NEXT: }
19+
// CPU-NEXT: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) {
20+
// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
21+
// CPU-NEXT: llvm.return
22+
// CPU-NEXT: }
1923
// CPU-NEXT: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
2024
// CPU-NEXT: func.func @main(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor<i32>) {
2125
// CPU-NEXT: %c = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>

test/lit_tests/linalg/lu_batched.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ module {
99
}
1010
}
1111

12+
// CPU: llvm.func private @enzymexla_lapack_sgetrf_wrapper(%arg0: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg1: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg2: !llvm.ptr {llvm.nofree}, %arg3: !llvm.ptr {llvm.nofree, llvm.readonly}, %arg4: !llvm.ptr {llvm.nofree, llvm.writeonly}, %arg5: !llvm.ptr {llvm.nofree, llvm.writeonly}) {
13+
// CPU-NEXT: llvm.call @enzymexla_lapack_sgetrf_(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
14+
// CPU-NEXT: llvm.return
15+
// CPU-NEXT: }
1216
// CPU: llvm.func @enzymexla_lapack_sgetrf_(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr)
1317
// CPU-NEXT: func.func @main(%arg0: tensor<4x3x64x64xf32>) -> (tensor<4x3x64x64xf32>, tensor<4x3x64xi32>, tensor<4x3x64xi32>, tensor<4x3xi32>) {
1418
// CPU: %c_0 = stablehlo.constant dense<1> : tensor<i32>
@@ -61,7 +65,7 @@ module {
6165
// CPU-NEXT: %3 = stablehlo.divide %iterArg, %c_2 : tensor<i64>
6266
// CPU-NEXT: %4 = stablehlo.dynamic_slice %arg0, %2, %3, %c_7, %c_7, sizes = [1, 1, 64, 64] : (tensor<4x3x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x64x64xf32>
6367
// CPU-NEXT: %5 = stablehlo.reshape %4 : (tensor<1x1x64x64xf32>) -> tensor<64x64xf32>
64-
// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_ (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
68+
// CPU-NEXT: %6:3 = enzymexla.jit_call @enzymexla_lapack_sgetrf_wrapper (%c_1, %c_1, %5, %c_1, %c_0, %c) {operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 4, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 5, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>], xla_side_effect_free} : (tensor<i64>, tensor<i64>, tensor<64x64xf32>, tensor<i64>, tensor<64xi64>, tensor<i64>) -> (tensor<64x64xf32>, tensor<64xi64>, tensor<i64>)
6569
// CPU-NEXT: %7 = stablehlo.reshape %6#0 : (tensor<64x64xf32>) -> tensor<1x1x64x64xf32>
6670
// CPU-NEXT: %8 = stablehlo.dynamic_update_slice %iterArg_8, %7, %2, %3, %c_7, %c_7 : (tensor<4x3x64x64xf32>, tensor<1x1x64x64xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<4x3x64x64xf32>
6771
// CPU-NEXT: %9 = stablehlo.reshape %6#1 : (tensor<64xi64>) -> tensor<1x1x64xi64>

0 commit comments

Comments
 (0)