Skip to content

Commit ab29083

Browse files
committed
Reverting changes to test_mlp
1 parent 36e295c commit ab29083

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

models/tt_transformers/tests/test_mlp.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from loguru import logger
1010

1111
import ttnn
12+
from models.tt_transformers.tests.test_utils import get_ref_model_dype
13+
from models.tt_transformers.tt.ccl import TT_CCL
1214
from models.tt_transformers.tt.mlp import MLP
1315
from models.tt_transformers.tt.model_config import ModelArgs
1416
from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
@@ -33,11 +35,12 @@
3335
"batch_size",
3436
(1,),
3537
)
38+
@pytest.mark.parametrize("device_params", [{"fabric_config": True}], indirect=True)
3639
def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds, ensure_gc):
3740
dtype = ttnn.bfloat8_b
3841
mode = "decode" if seq_len <= 32 else "prefill"
3942

40-
model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128)
43+
model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128, cache_hf=True)
4144
model_args.n_layers = 1
4245
state_dict = model_args.load_state_dict()
4346

@@ -50,16 +53,21 @@ def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)
5053
reference_model = model_args.reference_mlp()
5154
reference_model.load_state_dict(partial_state_dict)
5255

56+
tt_ccl = TT_CCL(mesh_device)
5357
tt_model = MLP(
5458
mesh_device=mesh_device,
59+
tt_ccl=tt_ccl,
5560
args=model_args,
5661
state_dict=state_dict,
5762
weight_cache_path=model_args.weight_cache_path(dtype),
5863
layer_num=0,
5964
dtype=dtype,
6065
model_config=model_args.get_model_config(),
6166
)
62-
torch_input = torch.randn(1, 1, seq_len, model_args.dim)
67+
68+
torch_input = torch.randn(
69+
1, 1, seq_len, model_args.dim, dtype=get_ref_model_dype(reference_model, model_args.model_name)
70+
)
6371
reference_output = reference_model(torch_input)
6472
tt_input = ttnn.from_torch(
6573
torch_input,

0 commit comments

Comments
 (0)