[sharktank] toy model for gpt-oss #2516
base: main
Conversation
result2 = decoder.prefill_cross_entropy([self.sequence])[0]
self.assertEqual(result.score, result2.score)
What is this doing?
I believe this is testing determinism between two prefills with the same inputs, which shouldn't be a concern once the model bring-up is fairly stable. So it can be removed.
assert result.valid

shark_ce = 4.6970133781433105
torch.testing.assert_close(result.score, shark_ce, atol=1e-2, rtol=1e-2)
How come the tolerances are so large?
FWIW, that's the tolerance used for the toy llama too; cross entropy is the concern here. Added a comment above to resolve it.
if ref_model is None:
    return
If ref_model is None we should be erroring out, not passing silently.
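A minimal sketch of the suggested change, reusing only names already in the diff:

```python
if ref_model is None:
    # Fail loudly instead of silently skipping the reference comparison.
    raise ValueError("ref_model is required; construct the reference model before comparing")
```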
)


def copy_weights_to_reference(shark_theta, ref_model, hp):
This does not belong in the toy model file.
def calculate_cross_entropy_manual(
    model_instance, sequence: list[int], use_prefill: bool = True
) -> tuple[float, float]:
    """Calculate cross entropy and perplexity manually for debugging."""
    evaluator = model_instance.make_perplexity_eval()
    if use_prefill:
        res = evaluator.prefill_cross_entropy([sequence])[0]
    else:
        res = evaluator.decode_cross_entropy([sequence])[0]

    assert res.valid
    ce = res.score
    ppl = float(torch.exp(torch.tensor(ce)))

    print("cross_entropy_nats:", ce)
    print("perplexity:", ppl)
    return ce, ppl
This does not belong here.
def make_simple_analytical_gpt_oss_theta(
    config: LlamaModelConfig,
    vocab_size: Optional[int] = None,
    dtype_rest: torch.dtype = torch.bfloat16,
    dtype_norm: torch.dtype = torch.bfloat16,
) -> Theta:
    """Generate a GPT-OSS theta with simple analytical weights for hand calculation."""
    return make_random_gpt_oss_theta(
        config=config,
        vocab_size=vocab_size,
        dtype_rest=dtype_rest,
        dtype_norm=dtype_norm,
        weight_generator=make_simple_calculable_weight_torch,
    )
This function is not needed. All it does is wrap make_random_gpt_oss_theta and pass one extra argument.
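Callers could invoke the underlying helper directly instead; a sketch using only names from the diff:

```python
theta = make_random_gpt_oss_theta(
    config=config,
    vocab_size=vocab_size,
    dtype_rest=torch.bfloat16,
    dtype_norm=torch.bfloat16,
    weight_generator=make_simple_calculable_weight_torch,
)
```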
Can rename to ref_pytorch_model.py and, if it's only used in testing, move it to tests/models/gpt_oss/.
)


def make_gpt_oss_attention_block_theta(
Ideally, we want to define theta layers in sharktank/layers/testing.py and construct the blocks accordingly under the respective /sharktank/models/<model_name>.
Also, try reusing existing make-theta functions (like attn, moe) wherever possible, adding extra layers like bias as you construct them here.
The vision models are not following this pattern and need to be consolidated to align with the LLMs.
@@ -0,0 +1,386 @@
import json
Add the following to all the new files:
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
    block_count=block_count,
)
decoder = instance.make_decoder()
generated_tokens = decoder.greedy_decode([[0]], steps=14)[0]
It's necessary to link the input here directly, by replacing [[0]] with self.sequence[0].
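A one-line sketch of the suggested linkage, assuming self.sequence is the hardcoded token list from setUp:

```python
# Seed greedy decoding with the first token of the shared test sequence.
generated_tokens = decoder.greedy_decode([[self.sequence[0]]], steps=14)[0]
```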
decoded = decoder.greedy_decode([[0]], steps=len(expected))[0]
decoded2 = decoder.greedy_decode([[0]], steps=len(expected))[0]
Same comment as above about linking input.
decoded = decoder.greedy_decode([[0]], steps=len(expected))[0]
decoded2 = decoder.greedy_decode([[0]], steps=len(expected))[0]

self.assertEqual(decoded, decoded2)
Determinism is verified by the assert below. This assert does almost the same thing, except it runs the test one more time. So it can be removed.
    return theta, config


def generate_analytical(
Is it possible to consolidate both generate functions by having a default LlamaHParams, where the analytical variant overrides only the necessary args?
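One possible shape for the consolidation, sketched under assumptions: make_random_gpt_oss_theta and make_simple_calculable_weight_torch come from the diff, while default_toy_hparams() and the default weight generator name are hypothetical placeholders:

```python
import dataclasses

# Hypothetical consolidated generator; variants differ only in overrides.
def generate(seed: int, weight_generator=make_random_weight_torch, **hp_overrides):
    torch.manual_seed(seed)
    # Start from one default LlamaHParams and replace only what a variant needs.
    hp = dataclasses.replace(default_toy_hparams(), **hp_overrides)
    config = LlamaModelConfig(hp=hp)
    theta = make_random_gpt_oss_theta(config=config, weight_generator=weight_generator)
    return theta, config

# The analytical variant then reduces to a thin call:
# theta, config = generate(seed=12345, weight_generator=make_simple_calculable_weight_torch)
```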
    ref_model.unembedding.weight.data = shark_theta("output", "weight").as_torch()


def calculate_cross_entropy_manual(
This essentially does what we have in the test below. I don't see a need to have this separately.
self.seed = 12345

# Hardcoded for CI performance - regenerate with self.generate_sequence() if weights change
self.sequence = [0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
Can we generate this input from the toy irpa model, so that the cross entropy isn't very high?
Start the prompt with 0, then run prefill or a subsequent decode step each time (see the sketch below):
[0] -> [2]
[0 2] -> [5]
[0 2 5] -> [9]
[0 2 5 9] -> [4]
...
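A hedged sketch of that regeneration loop, reusing the decoder from the test; the exact greedy_decode return shape is assumed:

```python
sequence = [0]  # start the prompt with 0
for _ in range(14):
    # Feed the prefix generated so far; with steps=1 the last element
    # of the result is the newly produced token either way.
    next_token = decoder.greedy_decode([sequence], steps=1)[0][-1]
    sequence.append(next_token)
print("regenerated sequence:", sequence)
```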
result2 = decoder.decode_cross_entropy([self.sequence])[0]
self.assertEqual(result.score, result2.score)
Can remove all asserts verifying determinism.
| """Test reference and sharktank model e2e comparison.""" | ||
|
|
||
| def setUp(self): | ||
| logging.basicConfig(level=logging.INFO) |
Shouldn't this be set globally for this whole test?
Also, the print statements here can be replaced with logging.
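A minimal sketch of that swap, assuming a module-level logger:

```python
logger = logging.getLogger(__name__)

# Replaces: print(f"Full test sequence: {full_sequence}")
logger.info("Full test sequence: %s", full_sequence)
```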
| print(f"Full test sequence: {full_sequence}") | ||
| return full_sequence | ||
|
|
||
| def testDecodeSequence(self): |
On that note, I'm not sure we need this test at all. We can stick with cross entropy for prefill/decode in eager/IREE modes.
    count += 1

ref_ce = total_loss / count
ref_ppl = float(torch.exp(torch.tensor(ref_ce)))
We don't seem to be verifying ref_ppl across both tests?
torch.testing.assert_close(
    shark_result.score, expected_ce, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2)
Shouldn't we be comparing ref_ce with shark_result.score?
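A sketch of the suggested direct comparison, keeping the tolerances from the diff:

```python
# Compare the reference cross entropy directly against the sharktank result.
torch.testing.assert_close(ref_ce, shark_result.score, atol=1e-2, rtol=1e-2)
```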
torch.testing.assert_close(
    shark_result.score, expected_ce, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(ref_ce, expected_ce, atol=1e-2, rtol=1e-2)
Same as above, compare ref_ce with shark_result.score?
ref_model.embedding.weight.data = shark_theta("token_embd", "weight").as_torch()

# Copy transformer blocks
for block_idx in range(hp.block_count):
Mapping from safetensors to irpa (or vice versa) is usually more readable with a mapping dict, like here. Is this something we can leverage here too?
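A hedged sketch of the mapping-dict approach; the irpa names come from the diff, while the reference-model attribute paths are illustrative assumptions:

```python
# Map reference-model parameter paths to irpa theta keys (illustrative entries).
WEIGHT_MAP = {
    "embedding.weight": ("token_embd", "weight"),
    "unembedding.weight": ("output", "weight"),
    # Per-block entries could be generated with f-strings over block_idx.
}

for ref_path, irpa_key in WEIGHT_MAP.items():
    # nn.Module.get_parameter resolves dotted paths like "embedding.weight".
    ref_model.get_parameter(ref_path).data = shark_theta(*irpa_key).as_torch()
```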
Adding a toy model and e2e testing for gpt-oss.