@@ -9,6 +9,8 @@
 from loguru import logger
 
 import ttnn
+from models.tt_transformers.tests.test_utils import get_ref_model_dype
+from models.tt_transformers.tt.ccl import TT_CCL
 from models.tt_transformers.tt.mlp import MLP
 from models.tt_transformers.tt.model_config import ModelArgs
 from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
@@ -33,11 +35,12 @@
     "batch_size",
     (1,),
 )
+@pytest.mark.parametrize("device_params", [{"fabric_config": True}], indirect=True)
 def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds, ensure_gc):
     dtype = ttnn.bfloat8_b
     mode = "decode" if seq_len <= 32 else "prefill"
 
-    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128)
+    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128, cache_hf=True)
     model_args.n_layers = 1
     state_dict = model_args.load_state_dict()
 
@@ -50,16 +53,21 @@ def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds, ensure_gc)
     reference_model = model_args.reference_mlp()
     reference_model.load_state_dict(partial_state_dict)
 
+    tt_ccl = TT_CCL(mesh_device)
     tt_model = MLP(
         mesh_device=mesh_device,
+        tt_ccl=tt_ccl,
         args=model_args,
         state_dict=state_dict,
         weight_cache_path=model_args.weight_cache_path(dtype),
         layer_num=0,
         dtype=dtype,
         model_config=model_args.get_model_config(),
     )
-    torch_input = torch.randn(1, 1, seq_len, model_args.dim)
+
+    torch_input = torch.randn(
+        1, 1, seq_len, model_args.dim, dtype=get_ref_model_dype(reference_model, model_args.model_name)
+    )
     reference_output = reference_model(torch_input)
     tt_input = ttnn.from_torch(
         torch_input,
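
Note on the new decorator added above: @pytest.mark.parametrize("device_params", [{"fabric_config": True}], indirect=True) hands the parametrized dict to a fixture named device_params rather than passing it straight into the test function. The sketch below is a minimal, generic illustration of pytest's indirect parametrization only; the fixture body is a hypothetical stand-in, not the real device_params fixture from the ttnn test infrastructure.

import pytest


@pytest.fixture
def device_params(request):
    # With indirect=True, the parametrized value arrives here as request.param
    # instead of being passed directly to the test. A real fixture would use it
    # to configure the device; this stand-in simply returns it.
    return request.param


@pytest.mark.parametrize("device_params", [{"fabric_config": True}], indirect=True)
def test_device_params_fixture_receives_value(device_params):
    # The test sees whatever the fixture produced for this parametrized case.
    assert device_params == {"fabric_config": True}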