55#include " mlir/IR/BuiltinOps.h"
66#include " mlir/Interfaces/FunctionInterfaces.h"
77#include " mlir/Pass/Pass.h"
8+ #include " llvm/ADT/BitVector.h"
89
10+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
11+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
912#include " src/enzyme_ad/jax/Dialect/Ops.h"
1013#include " stablehlo/dialect/StablehloOps.h"
14+ #include " triton/Dialect/Triton/IR/Dialect.h"
15+
16+ #include < queue>
1117
1218namespace mlir {
1319namespace enzyme {
@@ -50,33 +56,40 @@ struct MarkFunctionMemoryEffectsPass
5056 }
5157
5258 void
53- insertMemoryEffects (SmallVector< uint8_t , 4 > &effects,
59+ insertMemoryEffects (BitVector &effects,
5460 SmallVector<MemoryEffects::EffectInstance> memEffects) {
5561 for (auto &effect : memEffects)
5662 insertMemoryEffects (effects, effect);
5763 }
5864
59- void insertMemoryEffects (SmallVector<uint8_t , 4 > &effects) {
60- for (int i = 0 ; i < effects.size (); i++)
61- effects[i] = 1 ;
65+ void insertMemoryEffects (BitVector &effects) {
66+ effects.set (0 , effects.size ());
67+ }
68+
69+ void insertMemoryEffects (BitVector &effects, BitVector &argEffects) {
70+ for (int i = 0 ; i < effects.size (); i++) {
71+ if (argEffects[i]) {
72+ effects.set (i);
73+ }
74+ }
6275 }
6376
64- void insertMemoryEffects (SmallVector< uint8_t , 4 > &effects,
77+ void insertMemoryEffects (BitVector &effects,
6578 MemoryEffects::EffectInstance effect) {
6679 if (effect.getEffect () == MemoryEffects::Read::get ()) {
67- effects[ 0 ] = 1 ;
80+ effects. set ( 0 ) ;
6881 } else if (effect.getEffect () == MemoryEffects::Write::get ()) {
69- effects[ 1 ] = 1 ;
82+ effects. set ( 1 ) ;
7083 } else if (effect.getEffect () == MemoryEffects::Allocate::get ()) {
71- effects[ 2 ] = 1 ;
84+ effects. set ( 2 ) ;
7285 } else if (effect.getEffect () == MemoryEffects::Free::get ()) {
73- effects[ 3 ] = 1 ;
86+ effects. set ( 3 ) ;
7487 } else {
7588 assert (false && " unknown memory effect" );
7689 }
7790 }
7891
79- int64_t getNumEffects (SmallVector< uint8_t , 4 > &effects) {
92+ int64_t getNumEffects (BitVector &effects) {
8093 int64_t numEffects = 0 ;
8194 for (int i = 0 ; i < effects.size (); i++) {
8295 if (effects[i])
@@ -85,12 +98,171 @@ struct MarkFunctionMemoryEffectsPass
8598 return numEffects;
8699 }
87100
101+ struct EffectInfo {
102+ ArrayAttr enzymexlaEffects;
103+ bool readOnly;
104+ bool writeOnly;
105+ bool readNone;
106+ };
107+
108+ EffectInfo getEffectInfo (OpBuilder &builder, BitVector &effects) {
109+ EffectInfo info;
110+ info.readOnly = effects[0 ];
111+ info.writeOnly = effects[1 ];
112+ info.readNone = !effects[0 ] && !effects[1 ];
113+ SmallVector<Attribute> effectsAttrs;
114+
115+ if (effects[0 ]) {
116+ info.writeOnly = false ;
117+ effectsAttrs.push_back (builder.getStringAttr (" read" ));
118+ }
119+
120+ if (effects[1 ]) {
121+ info.readOnly = false ;
122+ effectsAttrs.push_back (builder.getStringAttr (" write" ));
123+ }
124+
125+ if (effects[2 ]) {
126+ info.writeOnly = false ;
127+ info.readOnly = false ;
128+ info.readNone = false ;
129+ effectsAttrs.push_back (builder.getStringAttr (" allocate" ));
130+ }
131+
132+ if (effects[3 ]) {
133+ info.writeOnly = false ;
134+ info.readOnly = false ;
135+ info.readNone = false ;
136+ effectsAttrs.push_back (builder.getStringAttr (" free" ));
137+ }
138+
139+ info.enzymexlaEffects = builder.getArrayAttr (effectsAttrs);
140+ return info;
141+ }
142+
143+ int32_t getArgIndex (CallOpInterface callOp, OpOperand *operand) {
144+ auto callOperands = callOp.getArgOperands ();
145+ for (unsigned i = 0 ; i < callOperands.size (); i++) {
146+ if (callOperands[i] == operand->get ())
147+ return i;
148+ }
149+ assert (false && " operand not found" );
150+ return -1 ;
151+ }
152+
153+ // TODO: at some point, we should reuse pre-existing attributes (see
154+ // jitcallsideeffect2.mlir)
155+ void handleCallOpInterface (
156+ CallOpInterface callOp, OpOperand *operand, BitVector &effects,
157+ DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
158+ if (auto calleeAttr = callOp.getCallableForCallee ()) {
159+ if (auto symRef = dyn_cast<SymbolRefAttr>(calleeAttr)) {
160+ if (funcArgEffects.contains (symRef)) {
161+ auto &argEffects = funcArgEffects[symRef];
162+ insertMemoryEffects (effects,
163+ argEffects[getArgIndex (callOp, operand)]);
164+ return ;
165+ } else {
166+ insertMemoryEffects (effects);
167+ return ;
168+ }
169+ }
170+ } else {
171+ insertMemoryEffects (effects);
172+ }
173+ }
174+
175+ bool isPointerType (Value v) { return isPointerType (v.getType ()); }
176+
177+ bool isPointerType (Type t) {
178+ return isa<LLVM::LLVMPointerType, MemRefType, triton::PointerType>(t);
179+ }
180+
181+ void analyzeMemoryEffects (
182+ Operation *op, OpOperand *operand, BitVector &effects,
183+ DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
184+ auto memEffectsOrNothing = getEffectsRecursively (op);
185+ if (!memEffectsOrNothing.has_value ()) {
186+ insertMemoryEffects (effects);
187+ return ;
188+ }
189+ auto &memEffects = memEffectsOrNothing.value ();
190+
191+ for (const auto &effect : memEffects) {
192+ if (effect.getValue () && effect.getValue () == operand->get ()) {
193+ if (isa<MemoryEffects::Read>(effect.getEffect ())) {
194+ effects.set (0 );
195+ } else if (isa<MemoryEffects::Write>(effect.getEffect ())) {
196+ effects.set (1 );
197+ } else if (isa<MemoryEffects::Allocate>(effect.getEffect ())) {
198+ effects.set (2 );
199+ } else if (isa<MemoryEffects::Free>(effect.getEffect ())) {
200+ effects.set (3 );
201+ } else {
202+ assert (false && " unknown memory effect" );
203+ }
204+ }
205+ }
206+ }
207+
208+ void analyzeFunctionArgumentMemoryEffects (
209+ FunctionOpInterface funcOp, SmallVector<BitVector> &argEffects,
210+ DenseMap<SymbolRefAttr, SmallVector<BitVector>> &funcArgEffects) {
211+ auto *ctx = funcOp->getContext ();
212+ OpBuilder builder (ctx);
213+
214+ DenseMap<Value, unsigned > valueToArgIndex;
215+ for (unsigned i = 0 ; i < funcOp.getNumArguments (); i++) {
216+ valueToArgIndex[funcOp.getArgument (i)] = i;
217+ }
218+
219+ // BFS traversal starting from arguments
220+ std::queue<Value> worklist;
221+ DenseSet<Value> visited;
222+ for (unsigned i = 0 ; i < funcOp.getNumArguments (); i++) {
223+ Value arg = funcOp.getArgument (i);
224+ worklist.push (arg);
225+ visited.insert (arg);
226+ }
227+
228+ // BFS through the graph
229+ while (!worklist.empty ()) {
230+ Value cur = worklist.front ();
231+ worklist.pop ();
232+
233+ auto argIt = valueToArgIndex.find (cur);
234+ if (argIt == valueToArgIndex.end ())
235+ continue ;
236+ unsigned argIndex = argIt->second ;
237+
238+ for (OpOperand &use : cur.getUses ()) {
239+ Operation *user = use.getOwner ();
240+
241+ if (auto callOp = dyn_cast<CallOpInterface>(user)) {
242+ handleCallOpInterface (callOp, &use, argEffects[argIndex],
243+ funcArgEffects);
244+ } else {
245+ analyzeMemoryEffects (user, &use, argEffects[argIndex],
246+ funcArgEffects);
247+ }
248+
249+ for (auto result : user->getResults ()) {
250+ if (visited.insert (result).second ) {
251+ valueToArgIndex[result] = argIndex;
252+ worklist.push (result);
253+ }
254+ }
255+ }
256+ }
257+ }
258+
88259 void runOnOperation () override {
89260 ModuleOp module = getOperation ();
90261 auto *ctx = module ->getContext ();
91262 OpBuilder builder (ctx);
92263
93- DenseMap<SymbolRefAttr, SmallVector<uint8_t , 4 >> funcEffects;
264+ DenseMap<SymbolRefAttr, BitVector> funcEffects;
265+ DenseMap<SymbolRefAttr, SmallVector<BitVector>> funcArgEffects;
94266
95267 CallGraph callGraph (module );
96268
@@ -114,7 +286,12 @@ struct MarkFunctionMemoryEffectsPass
114286 if (!funcOp)
115287 return signalPassFailure ();
116288
117- SmallVector<uint8_t , 4 > effects (4 , 0 );
289+ BitVector effects (4 , 0 );
290+ SmallVector<BitVector> argEffects;
291+ argEffects.reserve (funcOp.getNumArguments ());
292+ for (unsigned i = 0 ; i < funcOp.getNumArguments (); i++) {
293+ argEffects.push_back (BitVector (4 , 0 ));
294+ }
118295
119296 funcOp.walk ([&](Operation *op) {
120297 if (op->hasTrait <OpTrait::HasRecursiveMemoryEffects>()) {
@@ -154,12 +331,12 @@ struct MarkFunctionMemoryEffectsPass
154331 return WalkResult::advance ();
155332 });
156333
157- funcEffects[SymbolRefAttr::get (funcOp.getOperation ())] =
158- std::move (effects);
334+ auto symRef = SymbolRefAttr::get (funcOp.getOperation ());
335+ funcEffects[symRef] = std::move (effects);
336+ funcArgEffects[symRef] = std::move (argEffects);
159337 }
160338
161- auto propagate = [&](FunctionOpInterface funcOp,
162- SmallVector<uint8_t , 4 > &effects) {
339+ auto propagate = [&](FunctionOpInterface funcOp, BitVector &effects) {
163340 funcOp.walk ([&](Operation *op) {
164341 if (auto callOp = dyn_cast<CallOpInterface>(op)) {
165342 if (auto calleeAttr = callOp.getCallableForCallee ()) {
@@ -168,7 +345,7 @@ struct MarkFunctionMemoryEffectsPass
168345 auto funcEffectsSymRef = funcEffects.lookup (symRef);
169346 for (int i = 0 ; i < funcEffectsSymRef.size (); i++) {
170347 if (funcEffectsSymRef[i])
171- effects[i] = 1 ;
348+ effects. set (i) ;
172349 }
173350 }
174351 }
@@ -197,8 +374,10 @@ struct MarkFunctionMemoryEffectsPass
197374 if (!funcOp)
198375 continue ;
199376
200- auto &effects =
201- funcEffects[SymbolRefAttr::get (ctx, funcOp.getName ())];
377+ auto symRef = SymbolRefAttr::get (ctx, funcOp.getName ());
378+ analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
379+ funcArgEffects);
380+ auto &effects = funcEffects[symRef];
202381 size_t before = getNumEffects (effects);
203382 propagate (funcOp, effects);
204383 changed = getNumEffects (effects) != before;
@@ -211,8 +390,7 @@ struct MarkFunctionMemoryEffectsPass
211390 insertMemoryEffects (effects);
212391 }
213392 } else {
214- // No cycles: reverse topological order and propagate
215- for (CallGraphNode *node : llvm::reverse (topoOrder)) {
393+ for (CallGraphNode *node : topoOrder) {
216394 if (node->isExternal ())
217395 continue ;
218396
@@ -225,7 +403,10 @@ struct MarkFunctionMemoryEffectsPass
225403 if (!funcOp)
226404 continue ;
227405
228- auto &effects = funcEffects[SymbolRefAttr::get (ctx, funcOp.getName ())];
406+ auto symRef = SymbolRefAttr::get (ctx, funcOp.getName ());
407+ analyzeFunctionArgumentMemoryEffects (funcOp, funcArgEffects[symRef],
408+ funcArgEffects);
409+ auto &effects = funcEffects[symRef];
229410 propagate (funcOp, effects);
230411 }
231412 }
@@ -237,26 +418,37 @@ struct MarkFunctionMemoryEffectsPass
237418 if (!funcOp)
238419 continue ;
239420
240- SmallVector<Attribute> effectsAttrs;
241- for (int i = 0 ; i < effectsSet.size (); i++) {
242- if (effectsSet[i]) {
243- if (i == 0 ) {
244- effectsAttrs.push_back (builder.getStringAttr (" read" ));
245- } else if (i == 1 ) {
246- effectsAttrs.push_back (builder.getStringAttr (" write" ));
247- } else if (i == 2 ) {
248- effectsAttrs.push_back (builder.getStringAttr (" allocate" ));
249- } else if (i == 3 ) {
250- effectsAttrs.push_back (builder.getStringAttr (" free" ));
251- } else {
252- assert (false && " unknown memory effect" );
421+ auto funcEffectInfo = getEffectInfo (builder, effectsSet);
422+ funcOp->setAttr (" enzymexla.memory_effects" ,
423+ funcEffectInfo.enzymexlaEffects );
424+
425+ auto &argEffects = funcArgEffects[symbol];
426+ for (unsigned i = 0 ; i < funcOp.getNumArguments (); i++) {
427+ auto argEffectInfo = getEffectInfo (builder, argEffects[i]);
428+ funcOp.setArgAttr (i, " enzymexla.memory_effects" ,
429+ argEffectInfo.enzymexlaEffects );
430+
431+ if (isPointerType (funcOp.getArgument (i))) {
432+ if (argEffectInfo.readOnly ) {
433+ funcOp.setArgAttr (i, LLVM::LLVMDialect::getReadonlyAttrName (),
434+ builder.getUnitAttr ());
435+ }
436+ if (argEffectInfo.writeOnly ) {
437+ funcOp.setArgAttr (i, LLVM::LLVMDialect::getWriteOnlyAttrName (),
438+ builder.getUnitAttr ());
439+ }
440+ // if (argEffectInfo.readNone) {
441+ // funcOp.setArgAttr(i, LLVM::LLVMDialect::getReadnoneAttrName(),
442+ // builder.getUnitAttr());
443+ // }
444+ if (!argEffects[i][3 ]) {
445+ funcOp.setArgAttr (i, LLVM::LLVMDialect::getNoFreeAttrName (),
446+ builder.getUnitAttr ());
253447 }
254448 }
255449 }
256-
257- funcOp->setAttr (" enzymexla.memory_effects" ,
258- builder.getArrayAttr (effectsAttrs));
259450 }
260451 }
261452};
453+
262454} // namespace
0 commit comments