@@ -453,14 +453,14 @@ generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
453453}
454454
455455static func::FuncOp getOrInsertFuncDecl (ConversionPatternRewriter &rewriter,
456- mlir::ModuleOp parentModuleOp ,
456+ Operation *parentSymbolTableOp ,
457457 StringRef funcName, TypeRange inTypes,
458458 TypeRange outTypes) {
459459
460460 mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
461461 rewriter.setInsertionPointToStart (
462- &parentModuleOp. getRegion ().getBlocks ().front ());
463- SymbolTable st = SymbolTable (parentModuleOp );
462+ &parentSymbolTableOp-> getRegions (). front ().getBlocks ().front ());
463+ SymbolTable st = SymbolTable (parentSymbolTableOp );
464464 func::FuncOp fnOpLookup = st.lookup <func::FuncOp>(funcName);
465465 func::FuncOp fnOp;
466466 // if the function is already declared, use the existing function, don't
@@ -473,8 +473,8 @@ static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
473473 NamedAttribute funcAccess = NamedAttribute (t1, t2);
474474 FunctionType fnType =
475475 mlir::FunctionType::get (rewriter.getContext (), inTypes, outTypes);
476- fnOp = rewriter.create <func::FuncOp>(parentModuleOp. getLoc (), funcName ,
477- fnType, funcAccess);
476+ fnOp = rewriter.create <func::FuncOp>(parentSymbolTableOp-> getLoc (),
477+ funcName, fnType, funcAccess);
478478 }
479479 return fnOp;
480480}
@@ -1981,12 +1981,12 @@ struct ComputeExpOpByLUTLLVMPattern : OpConversionPattern<math::ExpOp> {
19811981
19821982 auto srcType = dyn_cast<VectorType>(adaptor.getOperand ().getType ());
19831983 StringRef funcName = " getExpBf16" ;
1984- auto moduleOp = expOp->getParentOfType <mlir::ModuleOp>();
19851984
19861985 VectorType v16bf16Ty = mlir::VectorType::get ({16 }, rewriter.getBF16Type ());
19871986 VectorType v8i64Ty = mlir::VectorType::get ({8 }, rewriter.getI64Type ());
19881987 func::FuncOp fnOp = getOrInsertFuncDecl (
1989- rewriter, moduleOp, funcName, TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
1988+ rewriter, expOp->getParentWithTrait <OpTrait::SymbolTable>(), funcName,
1989+ TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
19901990
19911991 SmallVector<Value> expOperands = {adaptor.getOperand ()};
19921992
0 commit comments