Skip to content

Commit 704a44a

Browse files
Patch conversion script qwen3 moe (#2425)
* chore- save preset
* format fix
* comments
1 parent 93d89d8 commit 704a44a

File tree

1 file changed: 11 additions (+11) and 15 deletions (-15).

tools/checkpoint_conversion/convert_qwen3_moe_checkpoints.py

Lines changed: 11 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
PRESET_MAP = {
2323
"qwen3_moe_30b_a3b_en": "Qwen/Qwen3-30B-A3B",
24+
"qwen3_moe_235b_a22b_en": "Qwen/Qwen3-235B-A22B",
2425
}
2526

2627
FLAGS = flags.FLAGS
@@ -85,21 +86,11 @@ def test_tokenizer(keras_hub_tokenizer, hf_tokenizer):
8586
np.testing.assert_equal(keras_hub_output, hf_output)
8687

8788

88-
def validate_output(
89-
keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer
90-
):
89+
def validate_output(qwen3_moe_lm, hf_model, hf_tokenizer):
9190
input_str = "What is Keras?"
9291
length = 32
9392

94-
# KerasHub
95-
preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor(
96-
keras_hub_tokenizer
97-
)
98-
qwen_moe_lm = keras_hub.models.Qwen3MoeCausalLM(
99-
backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy"
100-
)
101-
102-
keras_output = qwen_moe_lm.generate([input_str], max_length=length)
93+
keras_output = qwen3_moe_lm.generate([input_str], max_length=length)
10394
keras_output = keras_output[0]
10495
print("🔶 KerasHub output:", keras_output)
10596

@@ -150,11 +141,16 @@ def main(_):
150141
test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
151142
test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
152143

153-
# == Validate model.generate output ==
154-
validate_output(
155-
keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer
144+
preprocessor = keras_hub.models.Qwen3MoeCausalLMPreprocessor(
145+
keras_hub_tokenizer
156146
)
147+
qwen3_moe_lm = keras_hub.models.Qwen3MoeCausalLM(
148+
backbone=keras_hub_model, preprocessor=preprocessor, sampler="greedy"
149+
)
150+
# == Validate model.generate output ==
151+
validate_output(qwen3_moe_lm, hf_model, hf_tokenizer)
157152
print("\n-> Tests passed!")
153+
qwen3_moe_lm.save_to_preset(f"./{preset}")
158154

159155

160156
if __name__ == "__main__":

0 commit comments

Comments
 (0)