Skip to content

Commit fa6585d

Browse files
Merge pull request jax-ml#25006 from andportnoy:aportnoy/mosaic-gpu-kernel-custom-name
PiperOrigin-RevId: 702772768
2 parents 1da0379 + 7bd81db commit fa6585d

File tree

3 files changed

+55
-7
lines changed

3 files changed

+55
-7
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,7 @@ def _lower_as_gpu_kernel(
847847
out_shape,
848848
smem_scratch_shape: ShapeTree | Union[ShapeTree],
849849
module_name: str,
850+
kernel_name: str | None = None,
850851
prof_spec: profiler.ProfilerSpec | None = None,
851852
):
852853
ptr_ty = ir.Type.parse("!llvm.ptr")
@@ -873,6 +874,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
873874
module = ir.Module.create()
874875
attrs = module.operation.attributes
875876
attrs["sym_name"] = ir.StringAttr.get(module_name)
877+
if kernel_name is None:
878+
kernel_name = getattr(body, "__name__", "anonymous")
876879
with ir.InsertionPoint(module.body):
877880
_declare_runtime_functions()
878881
gmem_scratch_bytes = 0
@@ -882,7 +885,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
882885
ir.Attribute.parse("#llvm.linkage<external>"),
883886
addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory.
884887
)
885-
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
888+
@func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
886889
def main(token_ptr, buffers):
887890
nonlocal gmem_scratch_bytes
888891
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
@@ -947,6 +950,7 @@ def as_gpu_kernel(
947950
prof_spec: profiler.ProfilerSpec | None = None,
948951
cluster: tuple[int, int, int] = (1, 1, 1),
949952
module_name: str = "unknown",
953+
kernel_name: str | None = None,
950954
):
951955
if isinstance(in_shape, list):
952956
in_shape = tuple(in_shape)
@@ -956,7 +960,7 @@ def as_gpu_kernel(
956960
module, out_shape, unwrap_output_tuple = (
957961
_lower_as_gpu_kernel(
958962
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
959-
module_name, prof_spec
963+
module_name, kernel_name, prof_spec
960964
)
961965
)
962966

@@ -1014,6 +1018,7 @@ def as_torch_gpu_kernel(
10141018
prof_spec: profiler.ProfilerSpec | None = None,
10151019
cluster: tuple[int, int, int] = (1, 1, 1),
10161020
module_name: str = "unknown",
1021+
kernel_name: str | None = None,
10171022
):
10181023
try:
10191024
import torch
@@ -1032,7 +1037,7 @@ def as_torch_gpu_kernel(
10321037
module, out_shape, unwrap_output_tuple = (
10331038
_lower_as_gpu_kernel(
10341039
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
1035-
module_name, prof_spec
1040+
module_name, kernel_name, prof_spec
10361041
)
10371042
)
10381043

jaxlib/mosaic/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ cc_library(
126126
"@com_google_absl//absl/status",
127127
"@com_google_absl//absl/status:statusor",
128128
"@com_google_absl//absl/strings",
129+
"@com_google_absl//absl/strings:str_format",
129130
"@com_google_absl//absl/synchronization",
130131
"@llvm-project//llvm:Support",
131132
"@llvm-project//mlir:ArithDialect",

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ limitations under the License.
3838
#include "absl/status/status.h"
3939
#include "absl/status/statusor.h"
4040
#include "absl/strings/str_cat.h"
41+
#include "absl/strings/str_format.h"
4142
#include "absl/synchronization/mutex.h"
4243
#include "llvm/include/llvm/ADT/SmallVector.h"
4344
#include "llvm/include/llvm/Support/CodeGen.h"
@@ -415,6 +416,40 @@ GetKernelCache() {
415416
return std::make_pair(&context_cache, &mutex);
416417
}
417418

419+
// Finds the symbol names of the kernel's two C-interface entry points among
// the top level LLVM functions of `module_op`:
//   - the "host" function, named "_mlir_ciface_<foo>"
//   - the "init" function, named "_mlir_ciface_<foo>_init"
//
// Returns the pair (host name, init name) on success. Returns an internal
// error when the module does not hold exactly two functions carrying the
// "_mlir_ciface_" prefix, or when those two names are not related by the
// "_init" suffix.
absl::StatusOr<std::pair<std::string, std::string>> GetHostAndInitFuncNames(
    mlir::ModuleOp module_op) {
  constexpr std::string_view prefix = "_mlir_ciface_";

  // Gather the name of every LLVM function that carries the prefix.
  std::vector<std::string> ciface_symbols;
  for (mlir::LLVM::LLVMFuncOp llvm_func :
       module_op.getOps<mlir::LLVM::LLVMFuncOp>()) {
    const auto symbol = llvm_func.getName();
    if (!symbol.starts_with(prefix)) {
      continue;
    }
    ciface_symbols.push_back(symbol.str());
  }

  const auto num_found = ciface_symbols.size();
  if (num_found != 2) {
    return absl::InternalError(absl::StrFormat(
        "Expected to locate 2 symbols with %s prefix in the MLIR module, found "
        "%d instead.",
        prefix, num_found));
  }

  // The host name is a strict prefix of the init name, so lexicographic
  // ordering puts the host function first and the init function second.
  std::sort(ciface_symbols.begin(), ciface_symbols.end());
  const std::string& host_func_name = ciface_symbols.front();
  const std::string& init_func_name = ciface_symbols.back();

  const std::string expected_init_name = absl::StrCat(host_func_name, "_init");
  if (init_func_name != expected_init_name) {
    return absl::InternalError(absl::StrFormat(
        "Expected init function name to equal the concatenation of the host "
        "function name and the \"_init\" suffix, instead got "
        "init_func_name=%s, host_func_name=%s.",
        init_func_name, host_func_name));
  }
  return std::make_pair(host_func_name, init_func_name);
}
418453

419454
absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
420455
mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
@@ -430,9 +465,16 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
430465
return maybe_engine.status();
431466
}
432467
mlir::ExecutionEngine* execution_engine = maybe_engine->get();
433-
auto main = execution_engine->lookupPacked("_mlir_ciface_main");
434-
auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
435-
if (!init || !main) {
468+
469+
auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op);
470+
if (!host_and_init_func_names.ok()) {
471+
return host_and_init_func_names.status();
472+
}
473+
auto [host_name, init_name] = host_and_init_func_names.value();
474+
475+
auto host = execution_engine->lookupPacked(host_name);
476+
auto init = execution_engine->lookupPacked(init_name);
477+
if (!init || !host) {
436478
return absl::InternalError("Failed to retrieve kernel function");
437479
}
438480
void* module_ptr = nullptr;
@@ -442,7 +484,7 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
442484
void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
443485
reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
444486
return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
445-
reinterpret_cast<MosaicHostFunc*>(*main));
487+
reinterpret_cast<MosaicHostFunc*>(*host));
446488
}
447489

448490
// Each compiled kernel has a unique init func, and each kernel is used from

0 commit comments

Comments
 (0)