Skip to content

Commit 3b4039c

Browse files
apaszkejax authors
authored andcommitted
[Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering, which the gpu-to-llvm pass gracefully ignores unless compiled with debug assertions enabled. But, simply adding the assertions in fact makes the pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not what we want. That's why I've replaced it with a minimal version that is only repsponsible for handling the GPU dialect, making the lowering similar to the one prior to extra registrations. PiperOrigin-RevId: 641874183
1 parent 2ade7e7 commit 3b4039c

File tree

7 files changed

+217
-6
lines changed

7 files changed

+217
-6
lines changed

jaxlib/mosaic/gpu/BUILD

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,27 @@ py_library(
2828

2929
cc_library(
3030
name = "passes",
31-
srcs = ["launch_lowering.cc"],
32-
hdrs = ["launch_lowering.h"],
31+
srcs = [
32+
"launch_lowering.cc",
33+
"passes.cc",
34+
],
35+
hdrs = [
36+
"launch_lowering.h",
37+
"pass_boilerplate.h",
38+
"passes.h",
39+
],
3340
deps = [
3441
"@llvm-project//llvm:Support",
3542
"@llvm-project//mlir:DataLayoutInterfaces",
3643
"@llvm-project//mlir:FuncDialect",
3744
"@llvm-project//mlir:GPUDialect",
45+
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
3846
"@llvm-project//mlir:IR",
47+
"@llvm-project//mlir:LLVMCommonConversion",
3948
"@llvm-project//mlir:LLVMDialect",
4049
"@llvm-project//mlir:Pass",
4150
"@llvm-project//mlir:Support",
42-
"@com_google_absl//absl/log",
51+
"@llvm-project//mlir:TransformUtils",
4352
],
4453
)
4554

@@ -97,29 +106,38 @@ cc_library(
97106
":passes",
98107
"@llvm-project//llvm:Support",
99108
"@llvm-project//mlir:ArithDialect",
109+
"@llvm-project//mlir:ArithToLLVM",
100110
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
111+
"@llvm-project//mlir:ComplexToLLVM",
112+
"@llvm-project//mlir:ControlFlowToLLVM",
101113
"@llvm-project//mlir:ConversionPasses",
102114
"@llvm-project//mlir:ExecutionEngine",
103115
"@llvm-project//mlir:ExecutionEngineUtils",
104116
"@llvm-project//mlir:FuncDialect",
117+
"@llvm-project//mlir:FuncToLLVM",
105118
"@llvm-project//mlir:GPUDialect",
106119
"@llvm-project//mlir:GPUToLLVMIRTranslation",
107120
"@llvm-project//mlir:GPUTransforms",
108121
"@llvm-project//mlir:IR",
122+
"@llvm-project//mlir:IndexToLLVM",
109123
"@llvm-project//mlir:LLVMDialect",
110124
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
111125
"@llvm-project//mlir:MathDialect",
126+
"@llvm-project//mlir:MathToLLVM",
112127
"@llvm-project//mlir:MemRefDialect",
128+
"@llvm-project//mlir:MemRefToLLVM",
113129
"@llvm-project//mlir:MemRefTransforms",
114130
"@llvm-project//mlir:NVGPUDialect",
115131
"@llvm-project//mlir:NVVMDialect",
116132
"@llvm-project//mlir:NVVMTarget",
133+
"@llvm-project//mlir:NVVMToLLVM",
117134
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
118135
"@llvm-project//mlir:Parser",
119136
"@llvm-project//mlir:Pass",
120137
"@llvm-project//mlir:SCFDialect",
121138
"@llvm-project//mlir:Support",
122139
"@llvm-project//mlir:Transforms",
140+
"@llvm-project//mlir:UBToLLVM",
123141
"@llvm-project//mlir:VectorDialect",
124142
"@xla//xla/service:custom_call_status",
125143
"@xla//xla/service:custom_call_target_registry",

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,16 @@ limitations under the License.
3737
#include "llvm/include/llvm/ADT/SmallVector.h"
3838
#include "llvm/include/llvm/Support/CodeGen.h"
3939
#include "llvm/include/llvm/Support/TargetSelect.h"
40+
#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
41+
#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
42+
#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
43+
#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
44+
#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
45+
#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h"
46+
#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
47+
#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
4048
#include "mlir/include/mlir/Conversion/Passes.h"
49+
#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h"
4150
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
4251
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
4352
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -67,6 +76,7 @@ limitations under the License.
6776
#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
6877
#include "mlir/include/mlir/Transforms/Passes.h"
6978
#include "jaxlib/mosaic/gpu/launch_lowering.h"
79+
#include "jaxlib/mosaic/gpu/passes.h"
7080
#include "xla/service/custom_call_status.h"
7181
#include "xla/service/custom_call_target_registry.h"
7282

@@ -100,6 +110,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
100110
mlir::memref::registerMemRefPasses();
101111
mlir::registerGPUPasses();
102112
mosaic::gpu::registerGpuLaunchLoweringPass();
113+
mosaic::gpu::registerConvertGpuToLLVMPass();
103114
return true;
104115
}();
105116
(void)register_once;
@@ -123,7 +134,7 @@ mlir::FailureOr<mlir::OpPassManager> GetPassPipeline(
123134
gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}),
124135
gpu.module(cse),
125136
gpu.module(reconcile-unrealized-casts),
126-
gpu-to-llvm{gpu-binary-annotation=gpu.binary use-bare-pointers-for-host=false use-bare-pointers-for-kernels=false},
137+
mosaic-convert-gpu-to-llvm,
127138
gpu-module-to-binary{format=)" +
128139
mlir::gpu::stringifyCompilationTarget(target).str() + R"(},
129140
convert-math-to-llvm{approximate-log1p=true},
@@ -152,6 +163,16 @@ void InitContext(mlir::MLIRContext* context) {
152163
mlir::scf::SCFDialect, mlir::vector::VectorDialect,
153164
mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
154165
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
166+
mlir::registerConvertNVVMToLLVMInterface(registry);
167+
mlir::registerConvertComplexToLLVMInterface(registry);
168+
mlir::registerConvertMemRefToLLVMInterface(registry);
169+
mlir::registerConvertMathToLLVMInterface(registry);
170+
mlir::registerConvertFuncToLLVMInterface(registry);
171+
mlir::index::registerConvertIndexToLLVMInterface(registry);
172+
mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
173+
mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this
174+
mlir::arith::registerConvertArithToLLVMInterface(registry);
175+
mlir::registerFinalizeMemRefToLLVMConversionPass();
155176
mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
156177
mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry);
157178
mlir::registerBuiltinDialectTranslation(registry);

jaxlib/mosaic/gpu/launch_lowering.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ void buildInitFunction(mlir::OpBuilder &module_builder,
171171
used_smem = builder.create<mlir::LLVM::ConstantOp>(
172172
loc, i32,
173173
builder.getI32IntegerAttr(
174-
mlir::cast<mlir::IntegerAttr>(const_smem.getValue()).getSInt()));
174+
mlir::cast<mlir::IntegerAttr>(const_smem.getValue()).getInt()));
175175
}
176176
}
177177
mlir::Value kernel_handle =
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright 2024 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_
17+
#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_
18+
19+
#include "mlir/include/mlir/IR/DialectRegistry.h"
20+
#include "mlir/include/mlir/Pass/Pass.h"
21+
#include "mlir/include/mlir/Support/LLVM.h"
22+
#include "mlir/include/mlir/Support/TypeID.h"
23+
namespace mosaic {
24+
namespace gpu {
25+
26+
template <typename Derived, typename Op = void>
27+
class Pass : public ::mlir::OperationPass<Op> {
28+
public:
29+
Pass() : ::mlir::OperationPass<Op>(::mlir::TypeID::get<Derived>()) {}
30+
Pass(const Pass &other) : ::mlir::OperationPass<Op>(other) {}
31+
Pass &operator=(const Pass &) = delete;
32+
Pass(Pass &&) = delete;
33+
Pass &operator=(Pass &&) = delete;
34+
~Pass() = default;
35+
36+
static constexpr ::llvm::StringLiteral getArgumentName() {
37+
return ::llvm::StringLiteral(Derived::kArgumentName);
38+
}
39+
::llvm::StringRef getArgument() const override { return getArgumentName(); }
40+
::llvm::StringRef getDescription() const override { return ""; }
41+
static constexpr ::llvm::StringLiteral getPassName() {
42+
return ::llvm::StringLiteral(Derived::kPassName);
43+
}
44+
::llvm::StringRef getName() const override { return getPassName(); }
45+
static bool classof(const ::mlir::Pass *pass) {
46+
return pass->getTypeID() == ::mlir::TypeID::get<Derived>();
47+
}
48+
std::unique_ptr<::mlir::Pass> clonePass() const override {
49+
return std::make_unique<Derived>(*static_cast<const Derived *>(this));
50+
}
51+
void getDependentDialects(::mlir::DialectRegistry &registry) const override {}
52+
53+
private:
54+
using This =
55+
Pass<Derived, Op>; // Can't have a comma in the macro instantiation
56+
57+
public:
58+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This)
59+
};
60+
61+
} // namespace gpu
62+
} // namespace mosaic
63+
64+
#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_

jaxlib/mosaic/gpu/passes.cc

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/* Copyright 2024 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "jaxlib/mosaic/gpu/passes.h"
17+
#include <memory>
18+
#include <utility>
19+
#include <vector>
20+
21+
#include "llvm/include/llvm/ADT/StringRef.h"
22+
#include "llvm/include/llvm/Support/Debug.h"
23+
#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h"
24+
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
25+
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
26+
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
27+
#include "mlir/include/mlir/IR/BuiltinOps.h"
28+
#include "mlir/include/mlir/IR/SymbolTable.h"
29+
#include "mlir/include/mlir/Pass/PassRegistry.h"
30+
#include "mlir/include/mlir/Support/LLVM.h"
31+
#include "mlir/include/mlir/Transforms/DialectConversion.h"
32+
#include "jaxlib/mosaic/gpu/pass_boilerplate.h"
33+
34+
namespace mosaic {
35+
namespace gpu {
36+
37+
namespace {
38+
39+
class ConvertGpuToLLVMPass
40+
: public mosaic::gpu::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
41+
public:
42+
using mosaic::gpu::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp>::Pass;
43+
static constexpr llvm::StringLiteral kArgumentName =
44+
"mosaic-convert-gpu-to-llvm";
45+
static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass";
46+
47+
void runOnOperation() override {
48+
llvm::DebugFlag = true;
49+
mlir::MLIRContext *ctx = &getContext();
50+
mlir::RewritePatternSet patterns(ctx);
51+
mlir::LLVMTypeConverter converter(ctx);
52+
mlir::ConversionTarget target(*ctx);
53+
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
54+
target.addLegalOp<mlir::gpu::GPUModuleOp>();
55+
target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>(
56+
[&](mlir::gpu::LaunchFuncOp op) -> bool {
57+
return converter.isLegal(op->getOperandTypes()) &&
58+
converter.isLegal(op->getResultTypes());
59+
});
60+
auto symtab = mlir::SymbolTable(getOperation());
61+
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, "gpu.binary",
62+
false, &symtab);
63+
if (mlir::applyPartialConversion(getOperation(), target,
64+
std::move(patterns))
65+
.failed()) {
66+
signalPassFailure();
67+
}
68+
llvm::DebugFlag = false;
69+
}
70+
};
71+
72+
} // namespace
73+
74+
void registerConvertGpuToLLVMPass() {
75+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
76+
return std::make_unique<ConvertGpuToLLVMPass>();
77+
});
78+
}
79+
80+
} // namespace gpu
81+
} // namespace mosaic

jaxlib/mosaic/gpu/passes.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright 2024 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_
17+
#define JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_
18+
19+
namespace mosaic {
20+
namespace gpu {
21+
22+
void registerConvertGpuToLLVMPass();
23+
24+
} // namespace gpu
25+
} // namespace mosaic
26+
27+
#endif // JAXLIB_MOSAIC_GPU_CONVERT_GPU_TO_LLVM_H_

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def setUp(self):
170170

171171
class TestUtilTest(TestCase):
172172

173-
def test_copy(self):
173+
def test_copy_basic(self):
174174
def kernel(ctx, src, dst, _):
175175
copy(src, dst)
176176
x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3)

0 commit comments

Comments
 (0)