Currently, models in `torch_frame.nn` have a number of graph breaks, but we should be able to remove all or most of them to maximise performance optimisation opportunities. Specifically, the goal is to address as many graph breaks as possible in this test case:
`pytorch-frame/test/nn/models/test_compile.py`, lines 18 to 92 at commit `a3b73c4`:
```python
# Imports reconstructed for context; the excerpt itself starts at line 18
# of test_compile.py.
import pytest
import torch

from torch_frame import stype
from torch_frame.datasets import FakeDataset
from torch_frame.nn import (
    ExcelFormer,
    FTTransformer,
    ResNet,
    TabNet,
    TabTransformer,
    Trompt,
)


@pytest.mark.parametrize(
    "model_cls, model_kwargs, stypes, expected_graph_breaks",
    [
        pytest.param(
            FTTransformer,
            dict(channels=8),
            None,
            2,
            id="FTTransformer",
        ),
        pytest.param(ResNet, dict(channels=8), None, 2, id="ResNet"),
        pytest.param(
            TabNet,
            dict(
                split_feat_channels=2,
                split_attn_channels=2,
                gamma=0.1,
            ),
            None,
            7,
            id="TabNet",
        ),
        pytest.param(
            TabTransformer,
            dict(
                channels=8,
                num_heads=2,
                encoder_pad_size=2,
                attn_dropout=0.5,
                ffn_dropout=0.5,
            ),
            None,
            4,
            id="TabTransformer",
        ),
        pytest.param(
            Trompt,
            dict(channels=8, num_prompts=2),
            None,
            16,
            id="Trompt",
        ),
        pytest.param(
            ExcelFormer,
            dict(in_channels=8, num_cols=3, num_heads=1),
            [stype.numerical],
            4,
            id="ExcelFormer",
        ),
    ],
)
def test_compile_graph_break(
    model_cls,
    model_kwargs,
    stypes,
    expected_graph_breaks,
):
    torch._dynamo.config.suppress_errors = True
    dataset = FakeDataset(
        num_rows=10,
        with_nan=False,
        stypes=stypes or [stype.categorical, stype.numerical],
    )
    dataset.materialize()
    tf = dataset.tensor_frame
    model = model_cls(
        out_channels=1,
        num_layers=2,
        col_stats=dataset.col_stats,
        col_names_dict=tf.col_names_dict,
        **model_kwargs,
    )
    explanation = torch._dynamo.explain(model)(tf)
    assert explanation.graph_break_count <= expected_graph_breaks
```
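For orientation, here is a minimal sketch (not from pytorch-frame; the `Toy` module is hypothetical) of how `torch._dynamo.explain` surfaces a graph break caused by data-dependent Python control flow, one of the common causes we would need to eliminate:

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor.item() pulls a value out into Python, so Dynamo cannot
        # trace through the branch below and must split the graph here.
        if x.sum().item() > 0:
            return x + 1
        return x - 1


explanation = torch._dynamo.explain(Toy())(torch.randn(4))
print(explanation.graph_break_count)  # expect at least one break
```

Fixes for breaks like this typically mean rewriting the branch with tensor ops (e.g. `torch.where`) so the whole forward pass stays in a single graph.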
> **Note:** `torch._dynamo.explain()` doesn't show graph break reasons even when there are graph breaks. Instead, I suggest finding the graph break reasons with torch logs:

```sh
TORCH_LOGS=graph_breaks pytest test/nn/models/test_compile.py -k ExcelFormer
```
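The same logging can also be enabled programmatically, which is handy when running the test under a debugger; a minimal sketch, assuming the `torch._logging.set_logs` API available in recent PyTorch releases:

```python
import torch._logging

# Equivalent to TORCH_LOGS=graph_breaks: log the reason and source
# location of each graph break encountered during compilation.
torch._logging.set_logs(graph_breaks=True)
```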