|
21 | 21 |
|
22 | 22 | PRESET_MAP = {
|
23 | 23 | "qwen3_moe_30b_a3b_en": "Qwen/Qwen3-30B-A3B",
|
| 24 | + "qwen3_moe_235b_a22b_en": "Qwen/Qwen3-235B-A22B", |
24 | 25 | }
|
25 | 26 |
|
26 | 27 | FLAGS = flags.FLAGS
|
@@ -85,21 +86,11 @@ def test_tokenizer(keras_hub_tokenizer, hf_tokenizer):
|
85 | 86 | np.testing.assert_equal(keras_hub_output, hf_output)
|
86 | 87 |
|
87 | 88 |
|
88 |
| -def validate_output( |
89 |
| - keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer |
90 |
| -): |
| 89 | +def validate_output(qwen3_moe_lm, hf_model, hf_tokenizer): |
91 | 90 | input_str = "What is Keras?"
|
92 | 91 | length = 32
|
93 | 92 |
|
94 |
| - # KerasHub |
95 |
| - preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( |
96 |
| - keras_hub_tokenizer |
97 |
| - ) |
98 |
| - qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM( |
99 |
| - backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" |
100 |
| - ) |
101 |
| - |
102 |
| - keras_output = qwen_moe_lm.generate([input_str], max_length=length) |
| 93 | + keras_output = qwen3_moe_lm.generate([input_str], max_length=length) |
103 | 94 | keras_output = keras_output[0]
|
104 | 95 | print("🔶 KerasHub output:", keras_output)
|
105 | 96 |
|
@@ -150,11 +141,16 @@ def main(_):
|
150 | 141 | test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
|
151 | 142 | test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
|
152 | 143 |
|
153 |
| - # == Validate model.generate output == |
154 |
| - validate_output( |
155 |
| - keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer |
| 144 | + preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor( |
| 145 | + keras_hub_tokenizer |
156 | 146 | )
|
| 147 | + qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM( |
| 148 | + backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy" |
| 149 | + ) |
| 150 | + # == Validate model.generate output == |
| 151 | + validate_output(qwen3_moe_lm, hf_model, hf_tokenizer) |
157 | 152 | print("\n-> Tests passed!")
|
| 153 | + qwen3_moe_lm.save_to_preset(f"./{preset}") |
158 | 154 |
|
159 | 155 |
|
160 | 156 | if __name__ == "__main__":
|
|
0 commit comments