Skip to content

Commit dadc568

Browse files
wsmosesWilliam Moses
andauthored
WIP: kernels (#314)
* WIP: kernels * more files * fix * wip * wqtmp * wip * inc * continuing * wip * more work * inf rec * fix * overload working * continuing * continuing * push * conversion * continuing * fix * fix * host and device IR * Restore testing --------- Co-authored-by: William Moses <[email protected]>
1 parent 8688b5b commit dadc568

File tree

7 files changed

+539
-83
lines changed

7 files changed

+539
-83
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2121
[weakdeps]
2222
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2323
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
24+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2425
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2526
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2627
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
@@ -31,6 +32,7 @@ path = "lib/ReactantCore"
3132
[extensions]
3233
ReactantAbstractFFTsExt = "AbstractFFTs"
3334
ReactantArrayInterfaceExt = "ArrayInterface"
35+
ReactantCUDAExt = "CUDA"
3436
ReactantNNlibExt = "NNlib"
3537
ReactantStatisticsExt = "Statistics"
3638
ReactantYaoBlocksExt = "YaoBlocks"
@@ -58,4 +60,5 @@ julia = "1.10"
5860
[extras]
5961
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
6062
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
63+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6164
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

deps/ReactantExtra/API.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
470470
context.loadDialect<mlir::stablehlo::StablehloDialect>();
471471
context.loadDialect<mlir::chlo::ChloDialect>();
472472
}
473+
474+
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
475+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
476+
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
473477
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
474478
mlir::DialectRegistry &registry = *unwrap(creg);
475479

@@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
513517
mlir::affine::registerAffinePasses();
514518
mlir::registerReconcileUnrealizedCasts();
515519

520+
mlir::registerLLVMDialectImport(registry);
521+
mlir::registerNVVMDialectImport(registry);
522+
523+
mlir::LLVM::registerInlinerInterface(registry);
524+
516525
/*
517526
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
518527
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
540549
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
541550
}
542551

552+
553+
/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
554+
/// suffix in `lastUsedID`.
555+
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
556+
unsigned &lastUsedID,
557+
mlir::ModuleOp source,
558+
mlir::ModuleOp target) {
559+
using namespace llvm;
560+
using namespace mlir;
561+
SmallString<64> newSymName(oldSymName);
562+
newSymName.push_back('_');
563+
while (true) {
564+
auto possible = newSymName + Twine(++lastUsedID);
565+
if (!SymbolTable::lookupSymbolIn(source, possible.str()) && !SymbolTable::lookupSymbolIn(target, possible.str())) {
566+
return StringAttr::get(target.getContext(), possible);
567+
}
568+
}
569+
}
570+
571+
572+
/// Checks if a symbol with the same name as `op` already exists in `source`.
573+
/// If so, renames `op` and updates all its references in `target`.
574+
static mlir::LogicalResult
575+
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
576+
unsigned &lastUsedID) {
577+
using namespace llvm;
578+
using namespace mlir;
579+
580+
auto opName = op.getName().str();
581+
582+
if (!SymbolTable::lookupSymbolIn(target, opName)) {
583+
return success();
584+
}
585+
586+
StringAttr newSymName =
587+
renameSymbol(opName, lastUsedID, source, target);
588+
589+
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
590+
return op.emitError("unable to update all symbol uses for ")
591+
<< opName << " to " << newSymName;
592+
593+
SymbolTable::setSymbolName(op, newSymName);
594+
return success();
595+
}
596+
597+
extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
598+
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
599+
auto newMod = cast<ModuleOp>(*unwrap(newModC));
600+
601+
Operation* entryFn = nullptr;
602+
603+
unsigned lastUsedID = 0;
604+
605+
for (auto &op : *newMod.getBody()) {
606+
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
607+
if (!symbolOp)
608+
continue;
609+
610+
StringRef oldSymName = symbolOp.getName();
611+
612+
if (oldSymName == entryfn) {
613+
entryFn = &op;
614+
}
615+
616+
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
617+
lastUsedID))) {
618+
assert(0 && "failed to update all uses");
619+
}
620+
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
621+
}
622+
prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
623+
newMod.getBody()->getOperations());
624+
return wrap(entryFn);
625+
}
626+
543627
#pragma region xla::ifrt
544628

545629
#pragma region xla::ifrt::Value

deps/ReactantExtra/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ cc_library(
416416
"-Wl,-exported_symbol,_BufferToHost",
417417
"-Wl,-exported_symbol,_FreeClient",
418418
"-Wl,-exported_symbol,_ClientCompile",
419+
"-Wl,-exported_symbol,_LinkInModule",
419420
"-Wl,-exported_symbol,_FreeFuture",
420421
"-Wl,-exported_symbol,_FutureIsReady",
421422
"-Wl,-exported_symbol,_FutureAwait",
@@ -451,6 +452,10 @@ cc_library(
451452
"@llvm-project//mlir:TransformDialect",
452453
"@llvm-project//mlir:Transforms",
453454

455+
"@llvm-project//mlir:LLVMIRToLLVMTranslation",
456+
"@llvm-project//mlir:LLVMIRToNVVMTranslation",
457+
"@llvm-project//mlir:LLVMIRTransforms",
458+
454459
"@llvm-project//llvm:IRReader",
455460
"@llvm-project//llvm:Support",
456461
"@llvm-project//llvm:AArch64AsmParser",

0 commit comments

Comments
 (0)