Add canonical Grug MoE module, template variant, and contracts#3046
Add canonical Grug MoE module, template variant, and contracts#3046
Conversation
|
@claude ptal |
|
Claude finished @dlwh's task in 2m 57s —— View job 🤖
Review: Refactor Grug MoE core and activation handlingOverall this is a clean, well-structured MoE implementation. The separation of the functional Agreeing with self-review items1. jaxtyping annotations on +1 to the self-review. The existing 2. Comment wording (line 260) Agree with the suggested edit: 3. Agree. Currently batch_spec = x.sharding.spec[:1] + (P(None),) * (x.ndim - 1)or simply pulling the first partition from Additional observations4. Hardcoded activation in The functional 5. Activation dispatch: if callable(activation):
activation_fn = activation
else:
activation_fn = activation.to_jax_fn()Since if isinstance(activation, ActivationFunctionEnum):
activation_fn = activation.to_jax_fn()
else:
activation_fn = activation # already a CallableThis is more explicit about what the union type 6. if mesh is None:
mesh = get_abstract_mesh() # line 343-344
...
if mesh is None or mesh.empty: # line 364If
Minor nit but worth clarifying the contract. 7. Duplicated routing logic between Both functions independently compute 8. EP path group_sizes = jnp.bincount(expert_local, weights=valid_i32, length=local_experts).astype(jnp.int32)
9. EP
10. Test coverage Test coverage is solid: EP mesh, dense mesh, abstract mesh lowering, loss, shared expert disable, and functional activation enum vs callable equivalence. Two suggestions:
11. Clean and correct. SummaryGood refactor overall. The main actionable items are:
|
|
🤖 Followed up on the extra Claude review points as part of commit
I did not make |
|
@claude ptal |
|
Claude finished @dlwh's task in 2m 27s —— View job 🤖
Review: Grug MoE Module, Template Variant, and ContractsThe core MoE implementation in I have a few items below, ranging from structural concerns to minor nits. 1. Near-total duplication between
|
|
🤖 Grug variant diff report
|
Summary
lib/levanter/src/levanter/grug/grug_moe.py(Equinox style, EP-aware dispatch/collect, optional shared dense expert)router matmul + top-k + softmax) inline intoMoEMLP.__call__, and makemoe_mlpa reusable dispatch/permute/unpermute (+EP) kernel that consumes precomputedselected_expertsandcombine_weightsexperiments/grug/moe/(model.py,train.py,launch.py,__init__.py) aligned withexperiments/grug/baseexperiments/grug/test_variant_contracts.pyand standardize per-variant lowering hooks viadebug_mesh_and_token_pspec(num_devices)in variant model modulesActivationFunctionEnumconsistently and extendlevanter.utils.activationwithrelu2Validation
uv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug/test_grugformer_moe.py -quv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug -quv run pytest -o addopts='' tests/test_grug_base_template.py experiments/grug/test_variant_contracts.py -qNotes
--no-verifybecause the repo-wide pyrefly pre-commit hook currently fails with an existing project-exclude pattern error unrelated to this diff.