Skip to content

Commit 936ed09

Browse files
authored
feat: function argument memory effects for kernels (#1439)
* feat: function argument memory effects for kernels
* chore: run fmt
* feat: use llvm dialect attributes
* feat: merge the arg effects pass into func effects
* fix: correct use of CallOpInterface + only mark llvm attrs for ptrs
* fix: avoid readnone for now (drop me)
1 parent 7f88655 commit 936ed09

File tree

8 files changed

+283
-59
lines changed

8 files changed

+283
-59
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ cc_library(
786786
"@stablehlo//:stablehlo_ops",
787787
"@stablehlo//:stablehlo_passes",
788788
"@stablehlo//:stablehlo_type_inference",
789+
"@triton//:TritonDialects",
789790
"@xla//xla/mlir/utils:type_util",
790791
"@xla//xla/mlir_hlo",
791792
],

src/enzyme_ad/jax/Passes/MarkFunctionMemoryEffectsPass.cpp

Lines changed: 230 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
#include "mlir/IR/BuiltinOps.h"
66
#include "mlir/Interfaces/FunctionInterfaces.h"
77
#include "mlir/Pass/Pass.h"
8+
#include "llvm/ADT/BitVector.h"
89

10+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
912
#include "src/enzyme_ad/jax/Dialect/Ops.h"
1013
#include "stablehlo/dialect/StablehloOps.h"
14+
#include "triton/Dialect/Triton/IR/Dialect.h"
15+
16+
#include <queue>
1117

1218
namespace mlir {
1319
namespace enzyme {
@@ -50,33 +56,40 @@ struct MarkFunctionMemoryEffectsPass
5056
}
5157

5258
void
53-
insertMemoryEffects(SmallVector<uint8_t, 4> &effects,
59+
insertMemoryEffects(BitVector &effects,
5460
SmallVector<MemoryEffects::EffectInstance> memEffects) {
5561
for (auto &effect : memEffects)
5662
insertMemoryEffects(effects, effect);
5763
}
5864

59-
void insertMemoryEffects(SmallVector<uint8_t, 4> &effects) {
60-
for (int i = 0; i < effects.size(); i++)
61-
effects[i] = 1;
65+
// Conservative fallback: flags every effect bit in `effects` as present.
// Used by callers when nothing precise is known about an operation.
void insertMemoryEffects(BitVector &effects) {
  // The argument-less set() turns on all bits of the vector.
  effects.set();
}
68+
69+
// Merges `argEffects` into `effects`: every bit set in `argEffects` is also
// set in `effects`. Bits of `effects` that are already set are left untouched.
//
// Only indices that exist in `effects` are considered, so `effects` never
// changes size, and an `argEffects` that is shorter than `effects` is handled
// without reading past its end.
void insertMemoryEffects(BitVector &effects, BitVector &argEffects) {
  for (unsigned bit : argEffects.set_bits()) {
    if (bit < effects.size())
      effects.set(bit);
  }
}
6376

64-
void insertMemoryEffects(SmallVector<uint8_t, 4> &effects,
77+
void insertMemoryEffects(BitVector &effects,
6578
MemoryEffects::EffectInstance effect) {
6679
if (effect.getEffect() == MemoryEffects::Read::get()) {
67-
effects[0] = 1;
80+
effects.set(0);
6881
} else if (effect.getEffect() == MemoryEffects::Write::get()) {
69-
effects[1] = 1;
82+
effects.set(1);
7083
} else if (effect.getEffect() == MemoryEffects::Allocate::get()) {
71-
effects[2] = 1;
84+
effects.set(2);
7285
} else if (effect.getEffect() == MemoryEffects::Free::get()) {
73-
effects[3] = 1;
86+
effects.set(3);
7487
} else {
7588
assert(false && "unknown memory effect");
7689
}
7790
}
7891

79-
int64_t getNumEffects(SmallVector<uint8_t, 4> &effects) {
92+
int64_t getNumEffects(BitVector &effects) {
8093
int64_t numEffects = 0;
8194
for (int i = 0; i < effects.size(); i++) {
8295
if (effects[i])
@@ -85,12 +98,171 @@ struct MarkFunctionMemoryEffectsPass
8598
return numEffects;
8699
}
87100

101+
struct EffectInfo {
102+
ArrayAttr enzymexlaEffects;
103+
bool readOnly;
104+
bool writeOnly;
105+
bool readNone;
106+
};
107+
108+
// Translates an effect bit set into an EffectInfo.
//
// Bit layout of `effects`: 0 = read, 1 = write, 2 = allocate, 3 = free.
//
// The returned `enzymexlaEffects` array lists the names of the set bits in
// that order. `readOnly` / `writeOnly` hold when the respective bit is the
// only one set; `readNone` holds when no bit is set at all.
EffectInfo getEffectInfo(OpBuilder &builder, BitVector &effects) {
  const bool reads = effects[0];
  const bool writes = effects[1];
  const bool allocates = effects[2];
  const bool frees = effects[3];
  // Allocation or deallocation rules out all three summary flags.
  const bool allocOrFree = allocates || frees;

  SmallVector<Attribute> names;
  if (reads)
    names.push_back(builder.getStringAttr("read"));
  if (writes)
    names.push_back(builder.getStringAttr("write"));
  if (allocates)
    names.push_back(builder.getStringAttr("allocate"));
  if (frees)
    names.push_back(builder.getStringAttr("free"));

  EffectInfo summary;
  summary.readOnly = reads && !writes && !allocOrFree;
  summary.writeOnly = writes && !reads && !allocOrFree;
  summary.readNone = !reads && !writes && !allocOrFree;
  summary.enzymexlaEffects = builder.getArrayAttr(names);
  return summary;
}
142+
143+
// Returns the position of `operand` within the argument operands of `callOp`.
//
// The lookup is done by operand slot rather than by SSA value: when the same
// value is passed to several argument positions of one call, each use resolves
// to its own position instead of always resolving to the first occurrence.
//
// Asserts (and returns -1 when assertions are compiled out) if `operand` is
// not one of the argument operands of `callOp`.
int32_t getArgIndex(CallOpInterface callOp, OpOperand *operand) {
  auto callOperands = callOp.getArgOperands();
  // The begin index of an operand range is only defined for non-empty ranges.
  if (!callOperands.empty() && operand->getOwner() == callOp.getOperation()) {
    unsigned firstArgSlot = callOperands.getBeginOperandIndex();
    unsigned slot = operand->getOperandNumber();
    if (slot >= firstArgSlot && slot - firstArgSlot < callOperands.size())
      return static_cast<int32_t>(slot - firstArgSlot);
  }
  assert(false && "operand not found");
  return -1;
}
152+
153+
// TODO: at some point, we should reuse pre-existing attributes (see
154+
// jitcallsideeffect2.mlir)
155+
void handleCallOpInterface(
156+
CallOpInterface callOp, OpOperand *operand, BitVector &effects,
157+
DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
158+
if (auto calleeAttr = callOp.getCallableForCallee()) {
159+
if (auto symRef = dyn_cast<SymbolRefAttr>(calleeAttr)) {
160+
if (funcArgEffects.contains(symRef)) {
161+
auto &argEffects = funcArgEffects[symRef];
162+
insertMemoryEffects(effects,
163+
argEffects[getArgIndex(callOp, operand)]);
164+
return;
165+
} else {
166+
insertMemoryEffects(effects);
167+
return;
168+
}
169+
}
170+
} else {
171+
insertMemoryEffects(effects);
172+
}
173+
}
174+
175+
// Convenience overload: applies the type-based check to the type of `v`.
bool isPointerType(Value v) {
  Type valueType = v.getType();
  return isPointerType(valueType);
}
176+
177+
// Returns true when `t` is an LLVM pointer, a memref, or a Triton pointer.
bool isPointerType(Type t) {
  if (isa<LLVM::LLVMPointerType>(t))
    return true;
  if (isa<MemRefType>(t))
    return true;
  return isa<triton::PointerType>(t);
}
180+
181+
// Records in `effects` the memory effects that `op` declares on the value
// held by `operand`.
//
// Bit layout of `effects`: 0 = read, 1 = write, 2 = allocate, 3 = free.
//
// If the effects of `op` cannot be determined, every bit is set. Otherwise
// only effect instances whose associated value is exactly the operand's value
// contribute.
//
// NOTE(review): effect instances that carry no value are skipped here;
// confirm that such effects can never touch the tracked operand.
// `funcArgEffects` is not consulted by this function.
void analyzeMemoryEffects(
    Operation *op, OpOperand *operand, BitVector &effects,
    DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
  auto maybeOpEffects = getEffectsRecursively(op);
  if (!maybeOpEffects.has_value()) {
    // Unknown behaviour: assume the worst.
    insertMemoryEffects(effects);
    return;
  }

  for (const auto &instance : maybeOpEffects.value()) {
    // Skip effects that are not attached to the operand's value.
    if (!instance.getValue() || instance.getValue() != operand->get())
      continue;

    auto *kind = instance.getEffect();
    if (isa<MemoryEffects::Read>(kind)) {
      effects.set(0);
      continue;
    }
    if (isa<MemoryEffects::Write>(kind)) {
      effects.set(1);
      continue;
    }
    if (isa<MemoryEffects::Allocate>(kind)) {
      effects.set(2);
      continue;
    }
    if (isa<MemoryEffects::Free>(kind)) {
      effects.set(3);
      continue;
    }
    assert(false && "unknown memory effect");
  }
}
207+
208+
// Computes per-argument memory effects for `funcOp`.
//
// `argEffects` must hold one effect bit set per function argument; the bits
// found by the analysis are added to it. `funcArgEffects` supplies the
// per-argument effects of callees and is forwarded to the call handling.
//
// The analysis follows def-use chains starting at the function arguments.
// Every reached value remembers the set of arguments it was derived from, and
// each use of that value is charged to all of those arguments. A value derived
// from several arguments therefore contributes its effects to each of them,
// not just to the argument that happened to reach it first.
//
// NOTE(review): only operation results are followed; values forwarded into
// block arguments are not tracked by this traversal.
void analyzeFunctionArgumentMemoryEffects(
    FunctionOpInterface funcOp, SmallVector<BitVector> &argEffects,
    DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
  const unsigned numArgs = funcOp.getNumArguments();

  // For every reached value: which arguments it (transitively) derives from.
  DenseMap<Value, BitVector> valueToArgs;
  std::queue<Value> worklist;
  for (unsigned i = 0; i < numArgs; i++) {
    Value arg = funcOp.getArgument(i);
    BitVector origin(numArgs);
    origin.set(i);
    valueToArgs[arg] = origin;
    worklist.push(arg);
  }

  // Worklist iteration; a value is revisited whenever its origin set grows.
  // Origin sets only grow and are bounded by numArgs, so this terminates.
  while (!worklist.empty()) {
    Value cur = worklist.front();
    worklist.pop();

    // Take a copy: inserting results below may rehash the map.
    BitVector origins = valueToArgs.lookup(cur);

    for (OpOperand &use : cur.getUses()) {
      Operation *user = use.getOwner();
      auto callOp = dyn_cast<CallOpInterface>(user);

      // Charge this use to every argument the value derives from. Effects are
      // plain bit sets, so repeating this on a revisit is harmless.
      for (unsigned argIndex : origins.set_bits()) {
        if (callOp) {
          handleCallOpInterface(callOp, &use, argEffects[argIndex],
                                funcArgEffects);
        } else {
          analyzeMemoryEffects(user, &use, argEffects[argIndex],
                               funcArgEffects);
        }
      }

      // Results of the user inherit the origins of the value being used.
      for (auto result : user->getResults()) {
        BitVector &known = valueToArgs[result];
        if (known.size() < numArgs)
          known.resize(numArgs);
        BitVector merged = known;
        merged |= origins;
        if (merged != known) {
          known = merged;
          worklist.push(result);
        }
      }
    }
  }
}
258+
88259
void runOnOperation() override {
89260
ModuleOp module = getOperation();
90261
auto *ctx = module->getContext();
91262
OpBuilder builder(ctx);
92263

93-
DenseMap<SymbolRefAttr, SmallVector<uint8_t, 4>> funcEffects;
264+
DenseMap<SymbolRefAttr, BitVector> funcEffects;
265+
DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
94266

95267
CallGraph callGraph(module);
96268

@@ -114,7 +286,12 @@ struct MarkFunctionMemoryEffectsPass
114286
if (!funcOp)
115287
return signalPassFailure();
116288

117-
SmallVector<uint8_t, 4> effects(4, 0);
289+
BitVector effects(4, 0);
290+
SmallVector<BitVector> argEffects;
291+
argEffects.reserve(funcOp.getNumArguments());
292+
for (unsigned i = 0; i < funcOp.getNumArguments(); i++) {
293+
argEffects.push_back(BitVector(4, 0));
294+
}
118295

119296
funcOp.walk([&](Operation *op) {
120297
if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
@@ -154,12 +331,12 @@ struct MarkFunctionMemoryEffectsPass
154331
return WalkResult::advance();
155332
});
156333

157-
funcEffects[SymbolRefAttr::get(funcOp.getOperation())] =
158-
std::move(effects);
334+
auto symRef = SymbolRefAttr::get(funcOp.getOperation());
335+
funcEffects[symRef] = std::move(effects);
336+
funcArgEffects[symRef] = std::move(argEffects);
159337
}
160338

161-
auto propagate = [&](FunctionOpInterface funcOp,
162-
SmallVector<uint8_t, 4> &effects) {
339+
auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
163340
funcOp.walk([&](Operation *op) {
164341
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
165342
if (auto calleeAttr = callOp.getCallableForCallee()) {
@@ -168,7 +345,7 @@ struct MarkFunctionMemoryEffectsPass
168345
auto funcEffectsSymRef = funcEffects.lookup(symRef);
169346
for (int i = 0; i < funcEffectsSymRef.size(); i++) {
170347
if (funcEffectsSymRef[i])
171-
effects[i] = 1;
348+
effects.set(i);
172349
}
173350
}
174351
}
@@ -197,8 +374,10 @@ struct MarkFunctionMemoryEffectsPass
197374
if (!funcOp)
198375
continue;
199376

200-
auto &effects =
201-
funcEffects[SymbolRefAttr::get(ctx, funcOp.getName())];
377+
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
378+
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
379+
funcArgEffects);
380+
auto &effects = funcEffects[symRef];
202381
size_t before = getNumEffects(effects);
203382
propagate(funcOp, effects);
204383
changed = getNumEffects(effects) != before;
@@ -211,8 +390,7 @@ struct MarkFunctionMemoryEffectsPass
211390
insertMemoryEffects(effects);
212391
}
213392
} else {
214-
// No cycles: reverse topological order and propagate
215-
for (CallGraphNode *node : llvm::reverse(topoOrder)) {
393+
for (CallGraphNode *node : topoOrder) {
216394
if (node->isExternal())
217395
continue;
218396

@@ -225,7 +403,10 @@ struct MarkFunctionMemoryEffectsPass
225403
if (!funcOp)
226404
continue;
227405

228-
auto &effects = funcEffects[SymbolRefAttr::get(ctx, funcOp.getName())];
406+
auto symRef = SymbolRefAttr::get(ctx, funcOp.getName());
407+
analyzeFunctionArgumentMemoryEffects(funcOp, funcArgEffects[symRef],
408+
funcArgEffects);
409+
auto &effects = funcEffects[symRef];
229410
propagate(funcOp, effects);
230411
}
231412
}
@@ -237,26 +418,37 @@ struct MarkFunctionMemoryEffectsPass
237418
if (!funcOp)
238419
continue;
239420

240-
SmallVector<Attribute> effectsAttrs;
241-
for (int i = 0; i < effectsSet.size(); i++) {
242-
if (effectsSet[i]) {
243-
if (i == 0) {
244-
effectsAttrs.push_back(builder.getStringAttr("read"));
245-
} else if (i == 1) {
246-
effectsAttrs.push_back(builder.getStringAttr("write"));
247-
} else if (i == 2) {
248-
effectsAttrs.push_back(builder.getStringAttr("allocate"));
249-
} else if (i == 3) {
250-
effectsAttrs.push_back(builder.getStringAttr("free"));
251-
} else {
252-
assert(false && "unknown memory effect");
421+
auto funcEffectInfo = getEffectInfo(builder, effectsSet);
422+
funcOp->setAttr("enzymexla.memory_effects",
423+
funcEffectInfo.enzymexlaEffects);
424+
425+
auto &argEffects = funcArgEffects[symbol];
426+
for (unsigned i = 0; i < funcOp.getNumArguments(); i++) {
427+
auto argEffectInfo = getEffectInfo(builder, argEffects[i]);
428+
funcOp.setArgAttr(i, "enzymexla.memory_effects",
429+
argEffectInfo.enzymexlaEffects);
430+
431+
if (isPointerType(funcOp.getArgument(i))) {
432+
if (argEffectInfo.readOnly) {
433+
funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadonlyAttrName(),
434+
builder.getUnitAttr());
435+
}
436+
if (argEffectInfo.writeOnly) {
437+
funcOp.setArgAttr(i, LLVM::LLVMDialect::getWriteOnlyAttrName(),
438+
builder.getUnitAttr());
439+
}
440+
// if (argEffectInfo.readNone) {
441+
// funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadnoneAttrName(),
442+
// builder.getUnitAttr());
443+
// }
444+
if (!argEffects[i][3]) {
445+
funcOp.setArgAttr(i, LLVM::LLVMDialect::getNoFreeAttrName(),
446+
builder.getUnitAttr());
253447
}
254448
}
255449
}
256-
257-
funcOp->setAttr("enzymexla.memory_effects",
258-
builder.getArrayAttr(effectsAttrs));
259450
}
260451
}
261452
};
453+
262454
} // namespace

src/enzyme_ad/jax/Passes/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def PropagateConstantBoundsPass
6363

6464
def MarkFunctionMemoryEffectsPass : Pass<"mark-func-memory-effects", "ModuleOp"> {
6565
let summary = "Attach enzymexla.memory_effects attribute summarizing memory access";
66+
let dependentDialects = [
67+
"mlir::LLVM::LLVMDialect",
68+
"triton::TritonDialect",
69+
"memref::MemRefDialect",
70+
];
6671
let options = [
6772
Option<
6873
/*C++ variable name=*/"max_iterations",

0 commit comments

Comments
 (0)