Commit 198d775

make sure all of the model is on the same device, so this test will pass on multi-GPU (axolotl-ai-cloud#2524) [skip ci]
1 parent e4307fb commit 198d775

File tree

1 file changed: +2 −2 lines changed

tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
 def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
+        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
     )
     peft_config = get_peft_config(
         {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
     """Test LoRA kernel patches across different model architectures."""
     # Load model with appropriate dtype
     model = AutoModelForCausalLM.from_pretrained(
-        model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
+        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
     )

     # Apply LoRA configuration
```
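
For context: with `device_map="auto"`, `from_pretrained` lets Accelerate dispatch layers across all visible devices, so on a multi-GPU runner parts of the model can land on different GPUs and the test's tensors may end up on mismatched devices. Pinning `device_map="cuda:0"` keeps the whole model on a single GPU. Below is a minimal sketch, not part of the commit and assuming at least one CUDA device is available, of how one could verify that property after loading:

```python
# Illustrative check (not in the repository): confirm every parameter and
# buffer of the loaded model lives on a single device, cuda:0.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
)

devices = {p.device for p in model.parameters()}
devices |= {b.device for b in model.buffers()}
assert devices == {torch.device("cuda", 0)}, f"model is spread across {devices}"
```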
