
Commit 3bdb36e

Authored by Manan Shah (Manan17) and Vaibhav Jindal (vaibhavjindal)
Fix ci (#853)
## Summary

This fixes the CI. The mllama model does not support the accum_dtype param for transformers version 4.49.0. The qwen model required the out_hidden_size param in the config.

## Testing Done

- Hardware Type: H100
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Manan Shah <[email protected]>
Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent 90d66ce commit 3bdb36e

File tree: 3 files changed (+17, -3 lines)

src/liger_kernel/transformers/model/mllama.py

Lines changed: 4 additions & 2 deletions
@@ -190,7 +190,9 @@ def lce_forward(
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+    # Filter out accum_dtype from kwargs for model call as MllamaTextModel doesn't accept it in transformers 4.49.0
+    # but preserve it for loss function calls
+    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     outputs = self.model(
         input_ids=input_ids,
@@ -206,7 +208,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        **kwargs,
+        **model_kwargs,
     )
 
     hidden_states = outputs[0]
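For readers less familiar with the pattern, here is a minimal, self-contained sketch of the same kwargs-filtering idea. The names inner_forward and lce_forward_sketch are illustrative stand-ins rather than the actual Liger or transformers APIs; the point is simply that a keyword the callee does not declare must be dropped before delegation, while the caller keeps the original kwargs so the loss path can still read it.

# Sketch only: `inner_forward` stands in for MllamaTextModel.forward in
# transformers 4.49.0, which does not declare an `accum_dtype` parameter.
def inner_forward(input_ids, attention_mask=None):
    return {"hidden_states": input_ids}


def lce_forward_sketch(input_ids, **kwargs):
    # Drop the key the inner model cannot accept before delegating...
    model_kwargs = {k: v for k, v in kwargs.items() if k != "accum_dtype"}
    outputs = inner_forward(input_ids, **model_kwargs)
    # ...but the original kwargs still carry it for the loss computation.
    accum_dtype = kwargs.get("accum_dtype")
    return outputs, accum_dtype


outputs, accum_dtype = lce_forward_sketch([1, 2, 3], attention_mask=None, accum_dtype="float32")
# Forwarding accum_dtype straight through would instead raise:
# TypeError: inner_forward() got an unexpected keyword argument 'accum_dtype'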

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 12 additions & 1 deletion
@@ -557,6 +557,7 @@
             "hidden_size": 128,  # 1280
             "num_heads": 16,
             "in_chans": 3,
+            "out_hidden_size": 1024,
         },
         attn_implementation="sdpa",
     ),
@@ -860,7 +861,17 @@ def run_mini_model_multimodal(
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
-        output = model(**batch, accum_dtype=torch.float32)
+        supports_accum = getattr(model, "_supports_accum_dtype", None)
+        if supports_accum is None:
+            import inspect
+
+            params = inspect.signature(model.forward).parameters
+            supports_accum = ("accum_dtype" in params) or any(
+                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
+            )
+            setattr(model, "_supports_accum_dtype", supports_accum)
+
+        output = model(**batch, accum_dtype=torch.float32) if supports_accum else model(**batch)
         output.loss.backward()
         optimizer.step()
 
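The capability check added to the training loop above is plain standard-library introspection. Below is a standalone sketch of the same idea, using toy callables instead of a real HuggingFace model.forward; the helper name accepts_kwarg is hypothetical.

import inspect


def accepts_kwarg(fn, name):
    # True if `fn` declares `name` explicitly or accepts arbitrary **kwargs,
    # mirroring the supports_accum check in run_mini_model_multimodal.
    params = inspect.signature(fn).parameters
    return name in params or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def forward_explicit(x, accum_dtype=None):  # declares the parameter
    return x


def forward_var_kw(x, **kwargs):  # swallows it via **kwargs
    return x


def forward_strict(x):  # would raise TypeError if given accum_dtype
    return x


assert accepts_kwarg(forward_explicit, "accum_dtype")
assert accepts_kwarg(forward_var_kw, "accum_dtype")
assert not accepts_kwarg(forward_strict, "accum_dtype")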

test/convergence/fp32/test_mini_models_multimodal.py

Lines changed: 1 addition & 0 deletions
@@ -555,6 +555,7 @@
             "hidden_size": 128,  # 1280
             "num_heads": 16,
             "in_chans": 3,
+            "out_hidden_size": 1024,
         },
         attn_implementation="sdpa",
     ),
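A rough illustration of why both mini test configs needed the extra field: the commit summary says the qwen model requires out_hidden_size in its config, presumably because some layer is sized by it, so a config that omits the field fails when the model is built. The stand-in below uses SimpleNamespace and a toy projector; build_projector is not a real transformers API, and the real Qwen vision config class differs.

# Toy stand-in (assumed names, not transformers code): a layer sized by
# config.out_hidden_size cannot be built if the test config omits that field.
from types import SimpleNamespace

import torch.nn as nn


def build_projector(vision_config):
    # Projects vision features (hidden_size wide) out to out_hidden_size.
    return nn.Linear(vision_config.hidden_size, vision_config.out_hidden_size)


cfg_missing = SimpleNamespace(hidden_size=128, num_heads=16, in_chans=3)
cfg_fixed = SimpleNamespace(hidden_size=128, num_heads=16, in_chans=3, out_hidden_size=1024)

try:
    build_projector(cfg_missing)
except AttributeError as err:
    print(err)  # 'types.SimpleNamespace' object has no attribute 'out_hidden_size'

projector = build_projector(cfg_fixed)  # succeeds once out_hidden_size is present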
