7 changes: 5 additions & 2 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -280,9 +280,12 @@ def main():
     llama_config.hp.rope_interleave_emb = False
 
     # Override matmul_kernel if the weights were shuffled
-    if dataset.properties.get("use_shuffled_kernel", False):
+    shuffle_version = dataset.properties.get("use_shuffled_kernel", False)
+    if shuffle_version:
         kernel_selection = f"sharktank.asm.shuffled;{llama_config.matmul_kernel}"
-        logger.debug(f"Using preshuffle kernel variant: {kernel_selection}")
+        logger.debug(
+            f"Using preshuffle kernel variant: {kernel_selection} (version={shuffle_version})"
+        )
         llama_config.matmul_kernel = kernel_selection
 
     hp = llama_config.hp
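For context, a minimal standalone sketch of the dispatch this hunk implements. The helper name and the bare properties dict are hypothetical; the property values mirror the ones the importer (below) can write. Note that both True (v1) and "v2" are truthy, so legacy v1 datasets still select the preshuffled kernel variant:

# Hypothetical helper mirroring the export-side dispatch above.
def select_matmul_kernel(properties: dict, base_kernel: str) -> str:
    shuffle_version = properties.get("use_shuffled_kernel", False)
    if shuffle_version:  # True (v1) and "v2" are both truthy
        return f"sharktank.asm.shuffled;{base_kernel}"
    return base_kernel

assert select_matmul_kernel({"use_shuffled_kernel": "v2"}, "k") == "sharktank.asm.shuffled;k"
assert select_matmul_kernel({"use_shuffled_kernel": True}, "k") == "sharktank.asm.shuffled;k"
assert select_matmul_kernel({}, "k") == "k"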
8 changes: 8 additions & 0 deletions sharktank/sharktank/kernels/assembly_binaries.py

Large diffs are not rendered by default.

100 changes: 89 additions & 11 deletions sharktank/sharktank/kernels/gemm_fp4_asm.py

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -30,7 +30,7 @@
     _optional_int_prop,
     _int_prop,
 )
-from sharktank.kernels.gemm_fp4_asm import shuffle_weight
+from sharktank.kernels.gemm_fp4_asm import shuffle_weight, shuffle_scale
 
 
 def _load_json(p: Path):
@@ -133,9 +133,10 @@ def create_fp4_block_tensor(
 
     expected_shape = list(original_shape[:-1]) + [num_blocks, packed_block_size]
 
-    # Apply weight shuffling during preprocessing to avoid runtime shuffling (if enabled)
+    # Apply weight and scale shuffling during preprocessing to avoid runtime shuffling (if enabled)
     if apply_shuffle:
         weight_tensor = shuffle_weight(weight_tensor, layout=(16, 16))
+        scale_tensor = shuffle_scale(scale_tensor)
     weight_tensor = weight_tensor.view(*expected_shape)
 
     layout = BlockScaledFp4Layout(
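As a usage note, a sketch of the import-time preprocessing this hunk extends. The two shuffle calls match the ones added here; the tensor shapes and dtypes are assumptions for illustration only, not the actual Quark dataset layout:

import torch
from sharktank.kernels.gemm_fp4_asm import shuffle_weight, shuffle_scale

# Assumed, illustrative shapes: a byte-packed fp4 weight and its block scales.
weight_tensor = torch.zeros(4096, 2048, dtype=torch.uint8)
scale_tensor = torch.zeros(4096, 128, dtype=torch.uint8)

# Shuffle once at import time so the runtime kernel reads data in its
# preferred layout instead of reshuffling on every call.
weight_tensor = shuffle_weight(weight_tensor, layout=(16, 16))
scale_tensor = shuffle_scale(scale_tensor)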
@@ -514,7 +515,11 @@ def main(argv):
     updated_properties = convert_hf_hparams_to_gguf(ds.properties)
 
     # Store shuffle configuration for kernel selection
-    updated_properties["use_shuffled_kernel"] = args.apply_shuffle
+    # Version tracking: True (v1, weights only), "v2" (weights + scales), False (no preshuffle)
+    if args.apply_shuffle:
+        updated_properties["use_shuffled_kernel"] = "v2"  # weights + scales preshuffled
+    else:
+        updated_properties["use_shuffled_kernel"] = False
 
     head_count = (updated_properties["llama.attention.head_count"],)
 
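Finally, a sketch of how a reader could decode the three values this hunk can store. The helper is hypothetical; the value semantics come directly from the version-tracking comment above:

# Hypothetical decoder for the "use_shuffled_kernel" property.
def preshuffle_state(properties: dict) -> tuple[bool, bool]:
    """Return (weights_preshuffled, scales_preshuffled)."""
    version = properties.get("use_shuffled_kernel", False)
    if version == "v2":
        return True, True     # v2: weights + scales preshuffled
    if version is True:
        return True, False    # legacy v1: weights only
    return False, False       # no preshuffle

assert preshuffle_state({"use_shuffled_kernel": "v2"}) == (True, True)
assert preshuffle_state({"use_shuffled_kernel": True}) == (True, False)
assert preshuffle_state({"use_shuffled_kernel": False}) == (False, False)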