Skip to content

Commit b0c14e3

Browse files
authored
n_layer should be set when loading from a pretrained model (#435)
1 parent 181145a commit b0c14e3

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

example/gptj/jax_gptj.py

Lines changed: 6 additions & 4 deletions
Diff (unified; reconstructed from the garbled two-column rendering):

```diff
@@ -55,14 +55,16 @@
 print("---- Model loading", flush=True)
 model_id = args.model_id
 if args.dtype == "float16":
-    model = FlaxGPTJForCausalLM.from_pretrained(model_id, dtype=jax.numpy.float16)
+    model = FlaxGPTJForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_id, n_layer=args.num_layer, dtype=jax.numpy.float16)
     model.params = model.to_fp16(model.params)
 elif args.dtype == "bfloat16":
-    model = FlaxGPTJForCausalLM.from_pretrained(model_id, dtype=jax.numpy.bfloat16)
+    model = FlaxGPTJForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_id, n_layer=args.num_layer, dtype=jax.numpy.bfloat16)
     model.params = model.to_bf16(model.params)
 else:
-    model = FlaxGPTJForCausalLM.from_pretrained(model_id, dtype=jax.numpy.float32)
-model.config.n_layer = args.num_layer
+    model = FlaxGPTJForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_id, n_layer=args.num_layer, dtype=jax.numpy.float32)
 print(model.config)
 print("---- Model loading done", flush=True)
 
```

Summary of the change: instead of loading the full pretrained model and then
overwriting `model.config.n_layer` afterwards (which only edits the config
object and does not change the already-built parameters), `n_layer` is now
passed directly as a keyword argument to `from_pretrained` in all three dtype
branches, so the layer count takes effect at load time.
6870

0 commit comments

Comments (0)