Skip to content

Commit 7bd81db

Browse files
committed
[Mosaic GPU] Improve default kernel name and add option to customize
This allows users to distinguish Mosaic GPU kernels from other kernels when using profiling programs such as Nsight Systems. The new default behavior is to use `mosaic_gpu_<def_name>_kernel` as the kernel name, where `<def_name>` is the name of the Mosaic GPU Python kernel function passed to `as_gpu_kernel` or `as_torch_gpu_kernel`. We also add a new `kernel_name` optional argument to `as_gpu_kernel` and `as_torch_gpu_kernel`. If `kernel_name` is not `None`, the resulting kernel name is `mosaic_gpu_<kernel_name>_kernel`. This is useful when the Mosaic GPU Python kernel function is constructed through metaprogramming so that the final specialized kernel can have different meaningful names depending on the metaparameters. Previously the kernel name was always `main_kernel`.
1 parent f2f02ee commit 7bd81db

File tree

3 files changed

+55
-7
lines changed

3 files changed

+55
-7
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,7 @@ def _lower_as_gpu_kernel(
847847
out_shape,
848848
smem_scratch_shape: ShapeTree | Union[ShapeTree],
849849
module_name: str,
850+
kernel_name: str | None = None,
850851
prof_spec: profiler.ProfilerSpec | None = None,
851852
):
852853
ptr_ty = ir.Type.parse("!llvm.ptr")
@@ -873,6 +874,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
873874
module = ir.Module.create()
874875
attrs = module.operation.attributes
875876
attrs["sym_name"] = ir.StringAttr.get(module_name)
877+
if kernel_name is None:
878+
kernel_name = getattr(body, "__name__", "anonymous")
876879
with ir.InsertionPoint(module.body):
877880
_declare_runtime_functions()
878881
gmem_scratch_bytes = 0
@@ -882,7 +885,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
882885
ir.Attribute.parse("#llvm.linkage<external>"),
883886
addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory.
884887
)
885-
@func.FuncOp.from_py_func(ptr_ty, ptr_ty)
888+
@func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}")
886889
def main(token_ptr, buffers):
887890
nonlocal gmem_scratch_bytes
888891
token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
@@ -947,6 +950,7 @@ def as_gpu_kernel(
947950
prof_spec: profiler.ProfilerSpec | None = None,
948951
cluster: tuple[int, int, int] = (1, 1, 1),
949952
module_name: str = "unknown",
953+
kernel_name: str | None = None,
950954
):
951955
if isinstance(in_shape, list):
952956
in_shape = tuple(in_shape)
@@ -956,7 +960,7 @@ def as_gpu_kernel(
956960
module, out_shape, unwrap_output_tuple = (
957961
_lower_as_gpu_kernel(
958962
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
959-
module_name, prof_spec
963+
module_name, kernel_name, prof_spec
960964
)
961965
)
962966

@@ -1014,6 +1018,7 @@ def as_torch_gpu_kernel(
10141018
prof_spec: profiler.ProfilerSpec | None = None,
10151019
cluster: tuple[int, int, int] = (1, 1, 1),
10161020
module_name: str = "unknown",
1021+
kernel_name: str | None = None,
10171022
):
10181023
try:
10191024
import torch
@@ -1032,7 +1037,7 @@ def as_torch_gpu_kernel(
10321037
module, out_shape, unwrap_output_tuple = (
10331038
_lower_as_gpu_kernel(
10341039
body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
1035-
module_name, prof_spec
1040+
module_name, kernel_name, prof_spec
10361041
)
10371042
)
10381043

jaxlib/mosaic/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ cc_library(
126126
"@com_google_absl//absl/status",
127127
"@com_google_absl//absl/status:statusor",
128128
"@com_google_absl//absl/strings",
129+
"@com_google_absl//absl/strings:str_format",
129130
"@com_google_absl//absl/synchronization",
130131
"@llvm-project//llvm:Support",
131132
"@llvm-project//mlir:ArithDialect",

jaxlib/mosaic/gpu/custom_call.cc

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ limitations under the License.
3838
#include "absl/status/status.h"
3939
#include "absl/status/statusor.h"
4040
#include "absl/strings/str_cat.h"
41+
#include "absl/strings/str_format.h"
4142
#include "absl/synchronization/mutex.h"
4243
#include "llvm/include/llvm/ADT/SmallVector.h"
4344
#include "llvm/include/llvm/Support/CodeGen.h"
@@ -415,6 +416,40 @@ GetKernelCache() {
415416
return std::make_pair(&context_cache, &mutex);
416417
}
417418

419+
// Locates the two top-level C-interface entry points of `module_op`.
//
// The module's LLVM function ops are scanned for symbols that begin with
// "_mlir_ciface_". Exactly two such symbols are expected:
//   - the "host" function, named "_mlir_ciface_<foo>", and
//   - the "init" function, named "_mlir_ciface_<foo>_init".
//
// Returns the pair {host function name, init function name}. Returns an
// InternalError status if the number of matching symbols is not two, or if
// the two symbols do not differ exactly by the "_init" suffix.
absl::StatusOr<std::pair<std::string, std::string>> GetHostAndInitFuncNames(
    mlir::ModuleOp module_op) {
  constexpr std::string_view kCifacePrefix = "_mlir_ciface_";

  // Gather every LLVM function symbol carrying the C-interface prefix.
  std::vector<std::string> ciface_symbols;
  for (mlir::LLVM::LLVMFuncOp func_op :
       module_op.getOps<mlir::LLVM::LLVMFuncOp>()) {
    const auto symbol = func_op.getName();
    if (!symbol.starts_with(kCifacePrefix)) {
      continue;
    }
    ciface_symbols.push_back(symbol.str());
  }

  const auto num_symbols = ciface_symbols.size();
  if (num_symbols != 2) {
    return absl::InternalError(absl::StrFormat(
        "Expected to locate 2 symbols with %s prefix in the MLIR module, found "
        "%d instead.",
        kCifacePrefix, num_symbols));
  }

  // Lexicographic order places "_mlir_ciface_<foo>" ahead of
  // "_mlir_ciface_<foo>_init", since the former is a strict prefix of the
  // latter.
  std::sort(ciface_symbols.begin(), ciface_symbols.end());
  const std::string& host_symbol = ciface_symbols.front();
  const std::string& init_symbol = ciface_symbols.back();

  // Reject a pair of unrelated C-interface symbols.
  if (absl::StrCat(host_symbol, "_init") != init_symbol) {
    return absl::InternalError(absl::StrFormat(
        "Expected init function name to equal the concatenation of the host "
        "function name and the \"_init\" suffix, instead got "
        "init_func_name=%s, host_func_name=%s.",
        init_symbol, host_symbol));
  }
  return std::make_pair(host_symbol, init_symbol);
}
418453

419454
absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
420455
mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
@@ -430,9 +465,16 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
430465
return maybe_engine.status();
431466
}
432467
mlir::ExecutionEngine* execution_engine = maybe_engine->get();
433-
auto main = execution_engine->lookupPacked("_mlir_ciface_main");
434-
auto init = execution_engine->lookupPacked("_mlir_ciface_main_init");
435-
if (!init || !main) {
468+
469+
auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op);
470+
if (!host_and_init_func_names.ok()) {
471+
return host_and_init_func_names.status();
472+
}
473+
auto [host_name, init_name] = host_and_init_func_names.value();
474+
475+
auto host = execution_engine->lookupPacked(host_name);
476+
auto init = execution_engine->lookupPacked(init_name);
477+
if (!init || !host) {
436478
return absl::InternalError("Failed to retrieve kernel function");
437479
}
438480
void* module_ptr = nullptr;
@@ -442,7 +484,7 @@ absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
442484
void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
443485
reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
444486
return CompiledKernel(std::move(*maybe_engine), kernel_ptr,
445-
reinterpret_cast<MosaicHostFunc*>(*main));
487+
reinterpret_cast<MosaicHostFunc*>(*host));
446488
}
447489

448490
// Each compiled kernel has a unique init func, and each kernel is used from

0 commit comments

Comments
 (0)