Skip to content

Commit 73e6026

Browse files
authored
Merge branch 'main' into main
2 parents ab3a653 + 94e9576 commit 73e6026

38 files changed

+2863
-1013
lines changed

.buildkite/pipeline.yml

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,76 @@
11
steps:
2-
- label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
3-
matrix:
4-
setup:
5-
version:
6-
- "1.10"
7-
group:
8-
- core
9-
- neural_networks
10-
- integration
11-
plugins:
12-
- JuliaCI/julia#v1:
13-
version: "{{matrix.version}}"
14-
- JuliaCI/julia-coverage#v1:
15-
codecov: true
16-
dirs:
17-
- src
18-
- ext
19-
- lib/ReactantCore/src
20-
commands: |
21-
julia --project=. -e 'println("--- :julia: Instantiating project")
22-
using Pkg
23-
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'
2+
- group: ":test_tube: Tests"
3+
steps:
4+
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
5+
matrix:
6+
setup:
7+
version:
8+
- "1.10"
9+
group:
10+
- core
11+
- neural_networks
12+
- integration
13+
plugins:
14+
- JuliaCI/julia#v1:
15+
version: "{{matrix.version}}"
16+
- JuliaCI/julia-coverage#v1:
17+
codecov: true
18+
dirs:
19+
- src
20+
- ext
21+
- lib/ReactantCore/src
22+
commands: |
23+
julia --project=. -e 'println("--- :julia: Instantiating project")
24+
using Pkg
25+
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'
26+
27+
julia --project=. -e 'println("--- :julia: Run Tests")
28+
using Pkg
29+
Pkg.test(; coverage="user")'
30+
agents:
31+
queue: "juliagpu"
32+
cuda: "*"
33+
env:
34+
REACTANT_TEST_GROUP: "{{matrix.group}}"
35+
if: build.message !~ /\[skip tests\]/
36+
timeout_in_minutes: 120
37+
38+
- label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}"
39+
matrix:
40+
setup:
41+
version:
42+
- "1.10"
43+
- "1.11"
44+
group:
45+
- core
46+
- neural_networks
47+
- integration
48+
plugins:
49+
- JuliaCI/julia#v1:
50+
version: "{{matrix.version}}"
51+
- JuliaCI/julia-coverage#v1:
52+
codecov: true
53+
dirs:
54+
- src
55+
- ext
56+
- lib/ReactantCore/src
57+
commands: |
58+
julia --project=. -e 'println("--- :julia: Instantiating project")
59+
using Pkg
60+
Pkg.develop([PackageSpec(path="lib/ReactantCore")])'
2461
25-
julia --project=. -e 'println("--- :julia: Run Tests")
26-
using Pkg
27-
Pkg.test(; coverage="user")'
28-
agents:
29-
queue: "juliagpu"
30-
cuda: "*"
31-
env:
32-
REACTANT_TEST_GROUP: "{{matrix.group}}"
33-
if: build.message !~ /\[skip tests\]/
34-
timeout_in_minutes: 60
62+
julia --project=. -e 'println("--- :julia: Run Tests")
63+
using Pkg
64+
Pkg.test(; coverage="user")'
65+
agents:
66+
queue: "juliaecosystem"
67+
os: "linux"
68+
sandbox_capable: "true"
69+
arch: "aarch64"
70+
env:
71+
REACTANT_TEST_GROUP: "{{matrix.group}}"
72+
if: build.message !~ /\[skip tests\]/
73+
timeout_in_minutes: 120
3574

3675
- group: ":racehorse: Benchmarks"
3776
steps:

.github/workflows/CI.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
- integration
3434
arch:
3535
- x64
36+
- aarch64
3637
assertions:
3738
- false
3839
libReactant: [packaged]
@@ -49,6 +50,12 @@ jobs:
4950
version: '1.10'
5051
assertions: true
5152
test_group: neural_networks
53+
- os: ubuntu-20.04
54+
arch: x64
55+
libReactant: packaged
56+
version: '1.10'
57+
assertions: true
58+
test_group: integration
5259
- os: ubuntu-20.04
5360
arch: x86
5461
libReactant: packaged
@@ -64,6 +71,10 @@ jobs:
6471
libReactant: packaged
6572
version: '1.10'
6673
test_group: integration
74+
exclude:
75+
# these are run on Buildkite
76+
- os: ubuntu-20.04
77+
arch: aarch64
6778
steps:
6879
- uses: actions/checkout@v4
6980
- uses: julia-actions/setup-julia@v2

Project.toml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,29 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1616
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
17+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
1819
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
1920
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2021

2122
[weakdeps]
2223
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2324
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
25+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2426
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
27+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2528
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2629
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2730

28-
[sources.ReactantCore]
29-
path = "lib/ReactantCore"
31+
[sources]
32+
ReactantCore = {path = "lib/ReactantCore"}
3033

3134
[extensions]
3235
ReactantAbstractFFTsExt = "AbstractFFTs"
3336
ReactantArrayInterfaceExt = "ArrayInterface"
37+
ReactantCUDAExt = "CUDA"
3438
ReactantNNlibExt = "NNlib"
39+
ReactantRandom123Ext = "Random123"
3540
ReactantStatisticsExt = "Statistics"
3641
ReactantYaoBlocksExt = "YaoBlocks"
3742

@@ -40,15 +45,18 @@ AbstractFFTs = "1.5"
4045
Adapt = "4"
4146
ArrayInterface = "7.10"
4247
CEnum = "0.4, 0.5"
48+
CUDA = "5"
4349
Downloads = "1.6"
44-
Enzyme = "0.13.21"
50+
Enzyme = "0.13.22"
4551
EnzymeCore = "0.8.8"
4652
GPUArraysCore = "0.1.6, 0.2"
4753
LinearAlgebra = "1.10"
4854
NNlib = "0.9.26"
4955
OrderedCollections = "1"
5056
Preferences = "1.4"
51-
ReactantCore = "0.1.2"
57+
Random = "1.10"
58+
Random123 = "1.7"
59+
ReactantCore = "0.1.3"
5260
Reactant_jll = "0.0.26"
5361
Scratch = "1.2"
5462
Statistics = "1.10"
@@ -58,4 +66,5 @@ julia = "1.10"
5866
[extras]
5967
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
6068
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
69+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6170
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

deps/ReactantExtra/API.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
376376
return wrap(res);
377377
}
378378

379+
#include "llvm/IRReader/IRReader.h"
380+
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
381+
LLVMContext Context;
382+
SMDiagnostic Err;
383+
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
384+
mlir::MLIRContext &context = *unwrap(cctx);
385+
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
386+
return wrap(res);
387+
}
388+
379389

380390
/* Note that this */
381391
extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) {
@@ -460,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
460470
context.loadDialect<mlir::stablehlo::StablehloDialect>();
461471
context.loadDialect<mlir::chlo::ChloDialect>();
462472
}
473+
474+
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
475+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
476+
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
463477
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
464478
mlir::DialectRegistry &registry = *unwrap(creg);
465479

@@ -503,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
503517
mlir::affine::registerAffinePasses();
504518
mlir::registerReconcileUnrealizedCasts();
505519

520+
mlir::registerLLVMDialectImport(registry);
521+
mlir::registerNVVMDialectImport(registry);
522+
523+
mlir::LLVM::registerInlinerInterface(registry);
524+
506525
/*
507526
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
508527
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
@@ -530,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
530549
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
531550
}
532551

552+
553+
/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
554+
/// suffix in `lastUsedID`.
555+
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
556+
unsigned &lastUsedID,
557+
mlir::ModuleOp source,
558+
mlir::ModuleOp target) {
559+
using namespace llvm;
560+
using namespace mlir;
561+
SmallString<64> newSymName(oldSymName);
562+
newSymName.push_back('_');
563+
while (true) {
564+
auto possible = newSymName + Twine(++lastUsedID);
565+
if (!SymbolTable::lookupSymbolIn(source, possible.str()) && !SymbolTable::lookupSymbolIn(target, possible.str())) {
566+
return StringAttr::get(target.getContext(), possible);
567+
}
568+
}
569+
}
570+
571+
572+
/// Checks if a symbol with the same name as `op` already exists in `source`.
573+
/// If so, renames `op` and updates all its references in `target`.
574+
static mlir::LogicalResult
575+
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
576+
unsigned &lastUsedID) {
577+
using namespace llvm;
578+
using namespace mlir;
579+
580+
auto opName = op.getName().str();
581+
582+
if (!SymbolTable::lookupSymbolIn(target, opName)) {
583+
return success();
584+
}
585+
586+
StringAttr newSymName =
587+
renameSymbol(opName, lastUsedID, source, target);
588+
589+
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
590+
return op.emitError("unable to update all symbol uses for ")
591+
<< opName << " to " << newSymName;
592+
593+
SymbolTable::setSymbolName(op, newSymName);
594+
return success();
595+
}
596+
597+
extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
598+
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
599+
auto newMod = cast<ModuleOp>(*unwrap(newModC));
600+
601+
Operation* entryFn = nullptr;
602+
603+
unsigned lastUsedID = 0;
604+
605+
for (auto &op : *newMod.getBody()) {
606+
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
607+
if (!symbolOp)
608+
continue;
609+
610+
StringRef oldSymName = symbolOp.getName();
611+
612+
if (oldSymName == entryfn) {
613+
entryFn = &op;
614+
}
615+
616+
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
617+
lastUsedID))) {
618+
assert(0 && "failed to update all uses");
619+
}
620+
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
621+
}
622+
prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
623+
newMod.getBody()->getOperations());
624+
return wrap(entryFn);
625+
}
626+
533627
#pragma region xla::ifrt
534628

535629
#pragma region xla::ifrt::Value

deps/ReactantExtra/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ cc_library(
416416
"-Wl,-exported_symbol,_BufferToHost",
417417
"-Wl,-exported_symbol,_FreeClient",
418418
"-Wl,-exported_symbol,_ClientCompile",
419+
"-Wl,-exported_symbol,_LinkInModule",
419420
"-Wl,-exported_symbol,_FreeFuture",
420421
"-Wl,-exported_symbol,_FutureIsReady",
421422
"-Wl,-exported_symbol,_FutureAwait",
@@ -450,6 +451,12 @@ cc_library(
450451
"@llvm-project//mlir:SCFDialect",
451452
"@llvm-project//mlir:TransformDialect",
452453
"@llvm-project//mlir:Transforms",
454+
455+
"@llvm-project//mlir:LLVMIRToLLVMTranslation",
456+
"@llvm-project//mlir:LLVMIRToNVVMTranslation",
457+
"@llvm-project//mlir:LLVMIRTransforms",
458+
459+
"@llvm-project//llvm:IRReader",
453460
"@llvm-project//llvm:Support",
454461
"@llvm-project//llvm:AArch64AsmParser",
455462
"@llvm-project//llvm:AArch64CodeGen",

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pages = [
4343
],
4444
"MLIR API" => "api/mlirc.md",
4545
"XLA" => "api/xla.md",
46+
"Internal API" => "api/internal.md",
4647
],
4748
]
4849

docs/src/.vitepress/config.mts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ export default defineConfig({
7878
{ text: "MLIR API", link: "/api/mlirc" },
7979
{ text: "XLA", link: "/api/xla" },
8080
],
81-
}
81+
},
82+
{ text: "Internal API", link: "/api/internal" },
8283
],
8384
},
8485
{
@@ -132,6 +133,7 @@ export default defineConfig({
132133
{ text: "XLA", link: "/api/xla" },
133134
],
134135
},
136+
{ text: "Internal API", link: "/api/internal" },
135137
],
136138
},
137139
},

docs/src/api/internal.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Internal API
6+
7+
These functions are not part of the public API and are subject to change at any time.
8+
9+
```@docs
10+
Reactant.REDUB_ARGUMENTS_NAME
11+
Reactant.within_reactant_interpreter
12+
```

0 commit comments

Comments
 (0)