
Conversation

@oyazdanb
Contributor

Adding a toy model and e2e testing for gpt-oss.

@oyazdanb marked this pull request as ready for review on October 15, 2025, 15:15
@github-actions
Contributor

github-actions bot commented Oct 15, 2025

Coverage report


File — lines missing coverage (new statements):

  sharktank/sharktank/layers/
    mixture_of_experts_block.py
    paged_llama_attention_block.py
  sharktank/sharktank/models/gpt_oss/
    orig_pytorch_model.py: 90-116, 328
    testing.py: 184, 224
    toy_gpt_oss.py: 31, 122-134, 220-280, 284-297, 301
  sharktank/sharktank/models/llama/
    testing.py: 50-59
  sharktank/sharktank/models/llm/
    llm.py
  sharktank/sharktank/ops/
    attention_impls.py
    default_impls.py
  sharktank/tests/models/gpt_oss/
    toy_gpt_oss_test.py: 40-56, 223

This report was generated by python-coverage-comment-action

Comment on lines +82 to +83
result2 = decoder.prefill_cross_entropy([self.sequence])[0]
self.assertEqual(result.score, result2.score)
Contributor

What is this doing?

Collaborator

I believe this is testing determinism between two prefills with the same inputs, which shouldn't be a concern once the model bring-up is fairly stable. So it can be removed.

assert result.valid

shark_ce = 4.6970133781433105
torch.testing.assert_close(result.score, shark_ce, atol=1e-2, rtol=1e-2)
Contributor

How come the tolerances are so large?

Collaborator

FWIW, that's the tolerance used for toy llama too; cross entropy is the concern here. Added a comment above to resolve it.

Comment on lines +30 to +31
if ref_model is None:
    return
Contributor

If ref_model is None we should be erroring out, not passing silently.
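One way to do that, as a minimal sketch keeping the names from the quoted lines:

if ref_model is None:
    raise ValueError("ref_model is required for the reference comparison")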

)


def copy_weights_to_reference(shark_theta, ref_model, hp):
Contributor

This does not belong in the toy model file.

Comment on lines +118 to +134
def calculate_cross_entropy_manual(
    model_instance, sequence: list[int], use_prefill: bool = True
) -> tuple[float, float]:
    """Calculate cross entropy and perplexity manually for debugging."""
    evaluator = model_instance.make_perplexity_eval()
    if use_prefill:
        res = evaluator.prefill_cross_entropy([sequence])[0]
    else:
        res = evaluator.decode_cross_entropy([sequence])[0]

    assert res.valid
    ce = res.score
    ppl = float(torch.exp(torch.tensor(ce)))

    print("cross_entropy_nats:", ce)
    print("perplexity:", ppl)
    return ce, ppl
Contributor

This does not belong here.

Comment on lines +217 to +230
def make_simple_analytical_gpt_oss_theta(
    config: LlamaModelConfig,
    vocab_size: Optional[int] = None,
    dtype_rest: torch.dtype = torch.bfloat16,
    dtype_norm: torch.dtype = torch.bfloat16,
) -> Theta:
    """Generate a GPT-OSS theta with simple analytical weights for hand calculation."""
    return make_random_gpt_oss_theta(
        config=config,
        vocab_size=vocab_size,
        dtype_rest=dtype_rest,
        dtype_norm=dtype_norm,
        weight_generator=make_simple_calculable_weight_torch,
    )
Contributor

This function is not needed. All it does is wrap make_random_gpt_oss_theta and pass one extra argument.
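As a sketch, callers could invoke the existing helper directly with the same arguments the wrapper forwards:

theta = make_random_gpt_oss_theta(
    config=config,
    vocab_size=vocab_size,
    dtype_rest=torch.bfloat16,
    dtype_norm=torch.bfloat16,
    weight_generator=make_simple_calculable_weight_torch,
)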

Collaborator

This can be renamed to ref_pytorch_model.py and, if it is only used in testing, moved to tests/models/gpt_oss/.

)


def make_gpt_oss_attention_block_theta(
Collaborator

Ideally, we want to define theta layers in sharktank/layers/testing.py and construct the blocks accordingly under the respective sharktank/models/<model_name>. Also, try reusing the existing make-theta functions (like attn, moe) wherever possible and add the extra layers, like bias, as you construct them here.

The vision models are not following this pattern and need to be consolidated to align with the LLMs.

@@ -0,0 +1,386 @@
import json
Collaborator

@archana-ramalingam commented Oct 21, 2025

Add the following to all the new files:

# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

    block_count=block_count,
)
decoder = instance.make_decoder()
generated_tokens = decoder.greedy_decode([[0]], steps=14)[0]
Collaborator

It's necessary to link the input here directly, by replacing [[0]] with self.sequence[0].
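For example, one reading of that suggestion:

generated_tokens = decoder.greedy_decode([[self.sequence[0]]], steps=14)[0]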

Comment on lines +64 to +65
decoded = decoder.greedy_decode([[0]], steps=len(expected))[0]
decoded2 = decoder.greedy_decode([[0]], steps=len(expected))[0]
Collaborator

Same comment as above about linking input.

decoded = decoder.greedy_decode([[0]], steps=len(expected))[0]
decoded2 = decoder.greedy_decode([[0]], steps=len(expected))[0]

self.assertEqual(decoded, decoded2)
Collaborator

Determinism is verified by the assert below. This assert does almost the same thing, except it runs the test one more time, so it can be removed.

    return theta, config


def generate_analytical(
Collaborator

Is it possible to consolidate both generate functions by having a default LlamaHParams, with the analytical variant changing only the necessary args?
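A rough sketch of that consolidation, assuming the imports and helpers already in the file (make_rand_torch and make_default_toy_hparams are placeholder names, not the existing API):

def generate(seed, weight_generator=make_rand_torch, **hp_overrides):
    torch.manual_seed(seed)
    # One default LlamaHParams; callers override only what they need.
    hp = make_default_toy_hparams(**hp_overrides)
    config = LlamaModelConfig(hp=hp)
    theta = make_random_gpt_oss_theta(config=config, weight_generator=weight_generator)
    return theta, config


def generate_analytical(seed, **hp_overrides):
    # Identical construction; only the weight generator differs.
    return generate(seed, weight_generator=make_simple_calculable_weight_torch, **hp_overrides)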

    ref_model.unembedding.weight.data = shark_theta("output", "weight").as_torch()


def calculate_cross_entropy_manual(
Collaborator

This essentially does what we have in the test below. I don't see a need to have this separately.

self.seed = 12345

# Hardcoded for CI performance - regenerate with self.generate_sequence() if weights change
self.sequence = [0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
Collaborator

Can we generate this input from the toy irpa model, so that the cross entropy isn't very high? Start the prompt with 0, then run prefill (or a subsequent decode step) again to extend it, e.g. (sketch below):
[0] -> [2]
[0 2] -> [5]
[0 2 5] -> [9]
[0 2 5 9] -> [4]
...
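A sketch of how the regeneration helper could look, reusing the decoder the test already builds (self.instance and the greedy_decode usage mirror the calls quoted elsewhere in this PR and are illustrative only):

def generate_sequence(self, start_token=0, steps=14):
    decoder = self.instance.make_decoder()
    sequence = [start_token]
    for _ in range(steps):
        # Extend by one greedily chosen token, conditioned on everything generated so far.
        step_out = decoder.greedy_decode([sequence], steps=1)[0]
        sequence.append(step_out[-1])
    return sequence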

Comment on lines +97 to +98
result2 = decoder.decode_cross_entropy([self.sequence])[0]
self.assertEqual(result.score, result2.score)
Collaborator

Can remove all asserts verifying determinism.

"""Test reference and sharktank model e2e comparison."""

def setUp(self):
logging.basicConfig(level=logging.INFO)
Collaborator

Shouldn't this be set globally for this whole test?
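For instance, a sketch using the standard unittest hook so the configuration runs once for the whole class:

@classmethod
def setUpClass(cls):
    logging.basicConfig(level=logging.INFO)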

Collaborator

Also, the print statements here can be replaced with logging.
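e.g. with a module-level logger:

logger = logging.getLogger(__name__)

# ...then inside the test:
logger.info("Full test sequence: %s", full_sequence)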

print(f"Full test sequence: {full_sequence}")
return full_sequence

def testDecodeSequence(self):
Collaborator

On that note, I'm not sure we need this test at all. We can stick with cross entropy for prefill/decode in eager/IREE modes.

    count += 1

ref_ce = total_loss / count
ref_ppl = float(torch.exp(torch.tensor(ref_ce)))
Collaborator

We don't seem to be verifying ref_ppl across both tests?
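If that check is worth keeping, one possible form (deriving the sharktank perplexity from its score; tolerances mirror the existing asserts):

shark_ppl = float(torch.exp(torch.tensor(shark_result.score)))
torch.testing.assert_close(ref_ppl, shark_ppl, atol=1e-2, rtol=1e-2)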

torch.testing.assert_close(
    shark_result.score, expected_ce, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2)
Collaborator

Shouldn't we be comparing ref_ce with shark_result.score?
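i.e. something along the lines of:

torch.testing.assert_close(shark_result.score, ref_ce, atol=1e-2, rtol=1e-2)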

torch.testing.assert_close(
    shark_result.score, expected_ce, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2)
Collaborator

Same as above, compare ref_ce with shark_result.score?

ref_model.embedding.weight.data = shark_theta("token_embd", "weight").as_torch()

# Copy transformer blocks
for block_idx in range(hp.block_count):
Collaborator

Mapping from safetensors to irpa (or vice versa) is usually more readable with a mapping dict, like here. Is this something we can leverage here too?
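A rough sketch of the mapping-dict approach for this helper (the names on both sides are illustrative, not the actual parameter layouts):

PER_BLOCK_NAME_MAP = {
    "blk.{i}.attn_q.weight": "block.{i}.attn.q_proj.weight",
    "blk.{i}.attn_k.weight": "block.{i}.attn.k_proj.weight",
    # ...one entry per per-block parameter...
}


def copy_block_weights(shark_theta, ref_model, hp):
    for i in range(hp.block_count):
        for irpa_key, ref_key in PER_BLOCK_NAME_MAP.items():
            # Mirrors the shark_theta("token_embd", "weight") access pattern above;
            # the block index is passed as a string path segment.
            tensor = shark_theta(*irpa_key.format(i=i).split(".")).as_torch()
            ref_model.get_parameter(ref_key.format(i=i)).data = tensor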
