Draft
Changes from all commits — 44 commits
8de5bb5
init einsum
phu0ngng Dec 3, 2025
1f02cf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2025
bf3ebc2
code drop
pggPL Dec 10, 2025
76293d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
296d773
Add FP8 scale support and fix alignment for grouped GEMM
pggPL Dec 10, 2025
785df34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
1329b37
fix
pggPL Dec 10, 2025
47c58be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
a155a8a
Grouped GEMM: code cleanup and NULL C support
pggPL Dec 11, 2025
3b2fcdf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
5b0582b
Grouped GEMM: per-matrix alpha/beta support
pggPL Dec 11, 2025
101766b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
1167f75
Fix alpha/beta numel - use SimpleTensor::numel()
pggPL Dec 11, 2025
a5ee92f
Merge branch 'main' into einsum
jberchtold-nvidia Dec 16, 2025
00eb186
Einsum WIP 1
jberchtold-nvidia Dec 17, 2025
38defb8
Test
jberchtold-nvidia Dec 18, 2025
e4a80a3
Refactor: move grouped GEMM to separate file and cleanup API
pggPL Dec 19, 2025
db1e177
Merge branch 'main' into grouped_gemm
pggPL Dec 19, 2025
047a9f9
fix
pggPL Dec 19, 2025
c490e06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2025
e397845
batching working correctly for quant and gemm but slow
jberchtold-nvidia Dec 19, 2025
59145cc
fix
pggPL Dec 22, 2025
77b422a
Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM
pggPL Dec 22, 2025
9c8158e
fix
pggPL Dec 22, 2025
b1e0893
fix
jberchtold-nvidia Dec 22, 2025
f70f376
Merge remote-tracking branch 'github-upstream/main' into einsum
jberchtold-nvidia Dec 23, 2025
fb2067b
move einsum logic into TE
jberchtold-nvidia Dec 23, 2025
30716a6
einsum unit tests
jberchtold-nvidia Dec 23, 2025
349c315
fwd bwd einsum test
jberchtold-nvidia Dec 23, 2025
57ab3b0
unit tests passed with grouped gemm in bf16
jberchtold-nvidia Dec 23, 2025
ab98852
grouped quantization working for single gpu
jberchtold-nvidia Dec 23, 2025
1184796
Merge remote-tracking branch 'pawel/grouped_gemm' into einsum
jberchtold-nvidia Dec 23, 2025
f1fc31c
wip
jberchtold-nvidia Jan 5, 2026
c8cf763
with many hacks grouped gemm with new api works for a particular hard…
jberchtold-nvidia Jan 7, 2026
21e7002
progress
jberchtold-nvidia Jan 7, 2026
1ae08dd
more tests pass
jberchtold-nvidia Jan 7, 2026
fe39e39
einsum tests pass
jberchtold-nvidia Jan 7, 2026
5e47d57
more progress, works in maxtext single-gpu and is closer to bf16 batc…
jberchtold-nvidia Jan 8, 2026
bc6cf66
attempt at passing thru stateful args for DS
jberchtold-nvidia Jan 8, 2026
bcbe864
Revert "attempt at passing thru stateful args for DS"
jberchtold-nvidia Jan 8, 2026
b40353f
batch gemm specialization for CS amax calc
jberchtold-nvidia Jan 8, 2026
ee71c96
multi-GPU grouped quantize working now in shard_map (with hack to use…
jberchtold-nvidia Jan 15, 2026
9856862
reduce size of zero'ing memset to only uninitialized part of quantiza…
jberchtold-nvidia Jan 15, 2026
f58ba23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
85 changes: 85 additions & 0 deletions test_einsum.py
@@ -0,0 +1,85 @@
from enum import Enum

import jax
import jax.numpy as jnp
import numpy as np
import transformer_engine.jax as te
from transformer_engine.common.recipe import (
    Recipe,
    Float8CurrentScaling,
    MXFP8BlockScaling,
    DelayedScaling,
    NVFP4BlockScaling,
)
from flax import linen as nn


def make_einsum_cls(quantization_recipe):
    def te_einsum(generate_quantizer_set, s, x, kernel, **kwargs):
        def dot_general(x, kernel, dims, *args, **kwargs):
            contracting_dims, batch_dims = dims
            assert batch_dims == ((), ()), "Batch dims not supported in TE/JAX yet"

            quantizer_set = generate_quantizer_set("quantizer_set_for_einsum")
            return te.dense.dense(
                x,
                kernel,
                contracting_dims=contracting_dims,
                quantizer_set=quantizer_set,
            )

        return jnp.einsum(s, x, kernel, _dot_general=dot_general, **kwargs)

    return te.flax.wrap_function_in_te_state_module(te_einsum, quantization_recipe, "einsum")()
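

# A minimal, hedged sketch of the override mechanism used above: jnp.einsum accepts a
# private `_dot_general` argument and calls it for each pairwise contraction, which is
# how te_einsum routes the contraction through te.dense.dense. The pass-through below
# (a local helper, not part of this PR) simply delegates to jax.lax.dot_general, so it
# reproduces stock einsum and only illustrates the injection point.
def _passthrough_dot_general(lhs, rhs, dims, *args, **kwargs):
    # Same contraction as the default path; no quantization involved.
    return jax.lax.dot_general(lhs, rhs, dims, *args, **kwargs)
# Example: jnp.einsum("ij,jk->ik", a, b, _dot_general=_passthrough_dot_general)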


class EinsumType(Enum):
    JAX = "jax"
    TE = "te"


def main():

    class SimpleModel(nn.Module):

        einsum_type: EinsumType
        quantization_recipe: Recipe = None

        def _einsum(self, *args, **kwargs):
            if self.einsum_type == EinsumType.JAX:
                return jnp.einsum(*args, **kwargs)
            elif self.einsum_type == EinsumType.TE:
                # It is important to call make_einsum_cls(recipe) on every einsum call.
                # If make_einsum_cls were called once and reused, the state for stateful
                # recipes such as DelayedScaling would be incorrectly shared across calls
                # instead of each call owning its own state.
                return make_einsum_cls(self.quantization_recipe)(*args, **kwargs)
            else:
                raise ValueError(f"Unsupported einsum type: {self.einsum_type}")

        @nn.compact
        def __call__(self, x):
            kernel = self.param(
                "kernel", jax.nn.initializers.lecun_normal(), (32, 32), jnp.bfloat16
            )
            return self._einsum("ij,jk->ik", x, kernel)

    def test_model(einsum_type: EinsumType, quantization_recipe: Recipe = None):
        model = SimpleModel(einsum_type=einsum_type, quantization_recipe=quantization_recipe)
        x = jax.random.uniform(jax.random.PRNGKey(2), (32, 32), jnp.bfloat16)
        var_collect = model.init(jax.random.PRNGKey(3), x)
        # Use the full var_collect so that all state (e.g. quantizer state) is handled.
        # Passing only var_collect["params"] would break TE's state management for
        # recipes that require state (e.g. DelayedScaling).
        y = model.apply(var_collect, x)
        return y
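
    # Hedged illustration of the var_collect point above (not part of the original test):
    # for stateful recipes, model.init returns collections besides "params"; their names
    # are TE-internal, so treating them generically here is an assumption. These extra
    # collections are exactly what would be dropped by passing only var_collect["params"].
    def _non_param_collections(var_collect):
        # Empty for stateless recipes; non-empty for stateful ones such as DelayedScaling.
        return [name for name in var_collect if name != "params"]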

    # EinsumType.JAX: standard JAX einsum, used as the reference
    ref_out = test_model(einsum_type=EinsumType.JAX)

    # einsum routed through Transformer Engine's Float8CurrentScaling recipe
    te_out = test_model(einsum_type=EinsumType.TE, quantization_recipe=Float8CurrentScaling())

    # Compare outputs within FP8 E4M3 epsilon
    atol = float(jnp.finfo(jnp.float8_e4m3fn).eps)
    np.testing.assert_allclose(ref_out, te_out, atol=atol)
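
    # Hedged extension (not in the original test): the stateful DelayedScaling recipe
    # imported above could be exercised through the same path. Whether it matches the
    # bf16 reference within the same tolerance is an assumption, so only the output
    # shape is checked here rather than element-wise closeness.
    te_out_ds = test_model(einsum_type=EinsumType.TE, quantization_recipe=DelayedScaling())
    assert te_out_ds.shape == ref_out.shape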


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
@@ -30,6 +30,7 @@ add_executable(test_operator
test_causal_softmax.cu
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu
../test_common.cu)

# Find required packages