Skip to content

Commit cd480ba

Browse files
authored
feat: triton c api (#1467)
* feat: triton c api * feat: mosaic gpu build
1 parent a46511e commit cd480ba

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

deps/ReactantExtra/BUILD

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,24 @@ gentbl_cc_library(
13121312
],
13131313
)
13141314

1315+
gentbl_cc_library(
1316+
name = "MosaicGPUJLIncGen",
1317+
tbl_outs = [
1318+
(
1319+
[
1320+
"--generator=jl-op-defs",
1321+
"--disable-module-wrap=0",
1322+
],
1323+
"MosaicGPU.jl",
1324+
),
1325+
],
1326+
tblgen = "//:mlir-jl-tblgen",
1327+
td_file = "@jax//jaxlib/mosaic/dialect/gpu:mosaic_gpu.td",
1328+
deps = [
1329+
"@jax//jaxlib/mosaic/dialect/gpu:mosaic_gpu_td_files",
1330+
],
1331+
)
1332+
13151333
gentbl_cc_library(
13161334
name = "TritonJLIncGen",
13171335
tbl_outs = [
@@ -1442,14 +1460,19 @@ genrule(
14421460
"@stablehlo//:stablehlo/integrations/c/StablehloDialectApi.h",
14431461
"@stablehlo//:stablehlo/integrations/c/StablehloTypes.h",
14441462
"@shardy//shardy/integrations/c:attributes.h",
1463+
"@jax//jaxlib/triton:triton_dialect_capi.h",
1464+
"@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h",
1465+
"@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_passes.capi.h.inc",
1466+
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/attributes.h",
1467+
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h",
14451468
"//:Project.toml",
14461469
"//:Manifest.toml",
14471470
"//:wrap.toml",
14481471
"//:missing_defs.jl",
14491472
"//:make.jl",
14501473
],
14511474
outs = ["libMLIR_h.jl"],
1452-
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
1475+
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$@\"",
14531476
tags = [
14541477
"jlrule",
14551478
],

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ for file in [
3636
"Gpu.jl",
3737
"Affine.jl",
3838
"TPU.jl",
39+
"MosaicGPU.jl",
3940
"Triton.jl",
4041
"Shardy.jl",
4142
"MPI.jl",

deps/ReactantExtra/make.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ let options = deepcopy(options)
1919
genarg = first(eachsplit(ARGS[3], " "))
2020

2121
gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...)
22-
23-
hlo_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
24-
25-
sdy_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
22+
hlo_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...)
23+
sdy_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...)
24+
triton_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...)
25+
mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
26+
mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
2627

2728
append!(
2829
args,
@@ -37,6 +38,12 @@ let options = deepcopy(options)
3738
hlo_include_dir,
3839
"-I",
3940
sdy_include_dir,
41+
"-I",
42+
triton_include_dir,
43+
"-I",
44+
mosaic_tpu_include_dir,
45+
"-I",
46+
mosaic_gpu_include_dir,
4047
"-x",
4148
"c++",
4249
],
@@ -46,6 +53,9 @@ let options = deepcopy(options)
4653
detect_headers(include_dir, args, Dict(), endswith("Python/Interop.h"))...,
4754
detect_headers(hlo_include_dir, args, Dict())...,
4855
detect_headers(sdy_include_dir, args, Dict())...,
56+
detect_headers(triton_include_dir, args, Dict())...,
57+
detect_headers(mosaic_tpu_include_dir, args, Dict())...,
58+
detect_headers(mosaic_gpu_include_dir, args, Dict())...,
4959
]
5060

5161
ctx = create_context(headers, args, options)

0 commit comments

Comments
 (0)