Skip to content
11 changes: 6 additions & 5 deletions src/enzyme_ad/jax/Dialect/TritonExt/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ template <>
triton_ext::TritonCallOp ReadOnlyArg<triton_ext::TritonCallOp>::create(
PatternRewriter &rewriter, triton_ext::TritonCallOp launchOp,
ArrayRef<Type> resTys, ArrayAttr outputAliases) const {
return rewriter.create<triton_ext::TritonCallOp>(
launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getGridx(),
launchOp.getGridy(), launchOp.getGridz(), launchOp.getClusterx(),
launchOp.getClustery(), launchOp.getClusterz(), launchOp.getInputs(),
launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(),
return triton_ext::TritonCallOp::create(
rewriter, launchOp.getLoc(), resTys, launchOp.getFn(),
launchOp.getGridx(), launchOp.getGridy(), launchOp.getGridz(),
launchOp.getClusterx(), launchOp.getClustery(), launchOp.getClusterz(),
launchOp.getInputs(), launchOp.getBackendConfigAttr(),
launchOp.getOperandLayoutsAttr(),
/*resultLayouts*/ nullptr, launchOp.getArgAttrsAttr(),
launchOp.getResAttrsAttr(), outputAliases,
launchOp.getXlaSideEffectFreeAttr());
Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Passes/LinalgUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mlir::ArrayAttr getSHLOLayout(PatternRewriter &rewriter,
return rewriter.getArrayAttr(attrs);
}

std::optional<std::string> lapack_precision_prefix(Type elementType) {
std::optional<std::string> lapackPrecisionPrefix(Type elementType) {

// single-precision float
if (elementType.isF32()) {
Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Passes/LinalgUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ mlir::ArrayAttr getSHLOLayout(mlir::PatternRewriter &rewriter,
llvm::SmallVector<bool> isColMajorArr,
int64_t maxNumDims);

std::optional<std::string> lapack_precision_prefix(mlir::Type elementType);
std::optional<std::string> lapackPrecisionPrefix(mlir::Type elementType);

#endif // ENZYMEXLA_LINALGUTILS_H
91 changes: 38 additions & 53 deletions src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,20 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
LogicalResult matchAndRewrite(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return this->matchAndRewrite_cpu(op, rewriter);

return matchAndRewriteCPU(op, rewriter);
else if (backend == "cuda")
return this->matchAndRewrite_cuda(op, rewriter);

return matchAndRewriteCUDA(op, rewriter);
else if (backend == "tpu")
return this->matchAndRewrite_tpu(op, rewriter);

return matchAndRewriteTPU(op, rewriter);
else
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
"\"");
}

// TODO get matrix sizes dynamically so that we don't need to create a
// function wrapper for each op instance
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCPU(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand All @@ -80,7 +77,7 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);

std::string fn = "geqrf_";
if (auto prefix = lapack_precision_prefix(inputElementType)) {
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
fn = *prefix + fn;
} else {
op->emitOpError() << "Unsupported element type: " << inputElementType;
Expand Down Expand Up @@ -209,8 +206,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
return success();
}

LogicalResult matchAndRewrite_cuda(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCUDA(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -265,8 +262,8 @@ struct GeqrfOpLowering : public OpRewritePattern<enzymexla::GeqrfOp> {
return success();
}

LogicalResult matchAndRewrite_tpu(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteTPU(enzymexla::GeqrfOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -316,23 +313,20 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
LogicalResult matchAndRewrite(enzymexla::GeqrtOp op,
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return this->matchAndRewrite_cpu(op, rewriter);

return matchAndRewriteCPU(op, rewriter);
// else if (backend == "cuda")
// return this->matchAndRewrite_cuda(op, rewriter);

// return matchAndRewriteCUDA(op, rewriter);
// else if (backend == "tpu")
// return this->matchAndRewrite_tpu(op, rewriter);

// return matchAndRewriteTPU(op, rewriter);
else
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
"\"");
}

// TODO get matrix sizes dynamically so that we don't need to create a
// function wrapper for each op instance
LogicalResult matchAndRewrite_cpu(enzymexla::GeqrtOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCPU(enzymexla::GeqrtOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand All @@ -355,7 +349,7 @@ struct GeqrtOpLowering : public OpRewritePattern<enzymexla::GeqrtOp> {
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);

std::string fn = "geqrt_";
if (auto prefix = lapack_precision_prefix(inputElementType)) {
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
fn = *prefix + fn;
} else {
op->emitOpError() << "Unsupported element type: " << inputElementType;
Expand Down Expand Up @@ -523,23 +517,20 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
LogicalResult matchAndRewrite(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return this->matchAndRewrite_cpu(op, rewriter);

return matchAndRewriteCPU(op, rewriter);
else if (backend == "cuda")
return this->matchAndRewrite_cuda(op, rewriter);

return matchAndRewriteCUDA(op, rewriter);
else if (backend == "tpu")
return this->matchAndRewrite_tpu(op, rewriter);

return matchAndRewriteTPU(op, rewriter);
else
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
"\"");
}

// TODO get matrix sizes dynamically so that we don't need to create a
// function wrapper for each op instance
LogicalResult matchAndRewrite_cpu(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCPU(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -567,7 +558,7 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
auto type_llvm_void = LLVM::LLVMVoidType::get(ctx);

std::string fn = "gqr_";
if (auto prefix = lapack_precision_prefix(inputElementType)) {
if (auto prefix = lapackPrecisionPrefix(inputElementType)) {
if (prefix == "s" || prefix == "d")
fn = *prefix + "or" + fn;
else
Expand Down Expand Up @@ -688,8 +679,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
return success();
}

LogicalResult matchAndRewrite_cuda(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCUDA(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -734,8 +725,8 @@ struct OrgqrOpLowering : public OpRewritePattern<enzymexla::OrgqrOp> {
return success();
}

LogicalResult matchAndRewrite_tpu(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteTPU(enzymexla::OrgqrOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -772,23 +763,20 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
LogicalResult matchAndRewrite(enzymexla::OrmqrOp op,
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return this->matchAndRewrite_cpu(op, rewriter);

return matchAndRewriteCPU(op, rewriter);
// else if (backend == "cuda")
// return this->matchAndRewrite_cuda(op, rewriter);

// return matchAndRewriteCUDA(op, rewriter);
// else if (backend == "tpu")
// return this->matchAndRewrite_tpu(op, rewriter);

// return matchAndRewriteTPU(op, rewriter);
else
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
"\"");
}

// TODO get matrix sizes dynamically so that we don't need to create a
// function wrapper for each op instance
LogicalResult matchAndRewrite_cpu(enzymexla::OrmqrOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCPU(enzymexla::OrmqrOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -873,7 +861,7 @@ struct OrmqrOpLowering : public OpRewritePattern<enzymexla::OrmqrOp> {
auto type_llvm_char = rewriter.getIntegerType(8);

std::string fn = "mqr_";
if (auto prefix = lapack_precision_prefix(A_eltype)) {
if (auto prefix = lapackPrecisionPrefix(A_eltype)) {
if (prefix == "s" || prefix == "d")
fn = *prefix + "or" + fn;
else
Expand Down Expand Up @@ -1031,23 +1019,20 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
LogicalResult matchAndRewrite(enzymexla::GemqrtOp op,
PatternRewriter &rewriter) const override {
if (backend == "cpu")
return this->matchAndRewrite_cpu(op, rewriter);

return matchAndRewriteCPU(op, rewriter);
// else if (backend == "cuda")
// return this->matchAndRewrite_cuda(op, rewriter);

// return matchAndRewriteCUDA(op, rewriter);
// else if (backend == "tpu")
// return this->matchAndRewrite_tpu(op, rewriter);

// return matchAndRewriteTPU(op, rewriter);
else
return rewriter.notifyMatchFailure(op, "Unknown backend: \"" + backend +
"\"");
}

// TODO get matrix sizes dynamically so that we don't need to create a
// function wrapper for each op instance
LogicalResult matchAndRewrite_cpu(enzymexla::GemqrtOp op,
PatternRewriter &rewriter) const {
LogicalResult matchAndRewriteCPU(enzymexla::GemqrtOp op,
PatternRewriter &rewriter) const {
auto ctx = op->getContext();
LLVMTypeConverter typeConverter(ctx);

Expand Down Expand Up @@ -1141,7 +1126,7 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
auto type_llvm_char = rewriter.getIntegerType(8);

std::string fn = "gemqrt_";
if (auto prefix = lapack_precision_prefix(C_eltype)) {
if (auto prefix = lapackPrecisionPrefix(C_eltype)) {
fn = *prefix + fn;
} else {
op->emitOpError() << "Unsupported element type: " << C_eltype;
Expand Down
Loading
Loading