Skip to content

Commit 018094d

Browse files
Merge pull request #122 from johnmarktaylor91/feat/func-config
2 parents d02233b + 7144d6d commit 018094d

File tree

10 files changed

+891
-3
lines changed

10 files changed

+891
-3
lines changed

tests/test_func_config.py

Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
"""Tests for the func_config feature.
2+
3+
Verifies that func_config is correctly extracted for various operation types,
4+
is empty for source/output nodes, survives the postprocessing pipeline, and
5+
is accessible on both LayerPassLog and LayerLog.
6+
"""
7+
8+
import pytest
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
13+
import torchlens as tl
14+
from torchlens.capture.salient_args import extract_salient_args, _build_arg_name_map
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# Unit tests for extract_salient_args
19+
# ---------------------------------------------------------------------------
20+
21+
22+
class TestExtractSalientArgs:
23+
"""Unit tests for the extractor registry."""
24+
25+
def test_conv2d_basic(self):
26+
weight = torch.randn(16, 3, 3, 3)
27+
result = extract_salient_args(
28+
"conv2d", "conv2d", (torch.randn(1, 3, 8, 8), weight), {}, [(16, 3, 3, 3)]
29+
)
30+
assert result["out_channels"] == 16
31+
assert result["in_channels"] == 3
32+
assert result["kernel_size"] == (3, 3)
33+
assert "stride" not in result # default stride=1 suppressed
34+
assert "dilation" not in result # default dilation=1 suppressed
35+
36+
def test_conv2d_with_stride(self):
37+
result = extract_salient_args(
38+
"conv2d",
39+
"conv2d",
40+
(torch.randn(1, 3, 8, 8), torch.randn(16, 3, 3, 3)),
41+
{"stride": (2, 2)},
42+
[(16, 3, 3, 3)],
43+
)
44+
assert result["stride"] == (2, 2)
45+
46+
def test_conv2d_with_groups(self):
47+
result = extract_salient_args(
48+
"conv2d",
49+
"conv2d",
50+
(torch.randn(1, 4, 8, 8), torch.randn(4, 1, 3, 3)),
51+
{"groups": 4},
52+
[(4, 1, 3, 3)],
53+
)
54+
assert result["groups"] == 4
55+
56+
def test_linear(self):
57+
result = extract_salient_args(
58+
"linear", "linear", (torch.randn(1, 128), torch.randn(64, 128)), {}, [(64, 128)]
59+
)
60+
assert result["out_features"] == 64
61+
assert result["in_features"] == 128
62+
63+
def test_dropout(self):
64+
result = extract_salient_args("dropout", "dropout", (torch.randn(1, 10),), {"p": 0.3}, [])
65+
assert result["p"] == 0.3
66+
67+
def test_batch_norm(self):
68+
result = extract_salient_args(
69+
"batchnorm",
70+
"batch_norm",
71+
(torch.randn(1, 16, 4, 4),),
72+
{"eps": 1e-05, "momentum": 0.1},
73+
[],
74+
)
75+
assert result["eps"] == 1e-05
76+
assert result["momentum"] == 0.1
77+
78+
def test_layer_norm(self):
79+
result = extract_salient_args(
80+
"layernorm", "layer_norm", (torch.randn(1, 10),), {"normalized_shape": (10,)}, []
81+
)
82+
assert result["normalized_shape"] == (10,)
83+
84+
def test_max_pool(self):
85+
result = extract_salient_args(
86+
"maxpool2d",
87+
"max_pool2d",
88+
(torch.randn(1, 3, 8, 8),),
89+
{"kernel_size": 2, "stride": 2},
90+
[],
91+
)
92+
assert result["kernel_size"] == 2
93+
assert result["stride"] == 2
94+
95+
def test_adaptive_avg_pool(self):
96+
result = extract_salient_args(
97+
"adaptiveavgpool2d",
98+
"adaptive_avg_pool2d",
99+
(torch.randn(1, 3, 8, 8),),
100+
{"output_size": (1, 1)},
101+
[],
102+
)
103+
assert result["output_size"] == (1, 1)
104+
105+
def test_softmax(self):
106+
result = extract_salient_args("softmax", "softmax", (torch.randn(1, 10),), {"dim": 1}, [])
107+
assert result["dim"] == 1
108+
109+
def test_cat(self):
110+
t = torch.randn(1, 3)
111+
result = extract_salient_args("cat", "cat", ([t, t],), {"dim": 1}, [])
112+
assert result["dim"] == 1
113+
114+
def test_reduction_mean(self):
115+
result = extract_salient_args(
116+
"mean", "mean", (torch.randn(2, 3, 4),), {"dim": (1, 2), "keepdim": True}, []
117+
)
118+
assert result["dim"] == (1, 2)
119+
assert result["keepdim"] is True
120+
121+
def test_clamp(self):
122+
result = extract_salient_args(
123+
"clamp", "clamp", (torch.randn(3),), {"min": 0.0, "max": 1.0}, []
124+
)
125+
assert result["min"] == 0.0
126+
assert result["max"] == 1.0
127+
128+
def test_transpose(self):
129+
result = extract_salient_args("transpose", "transpose", (torch.randn(2, 3), 0, 1), {}, [])
130+
assert result["dim0"] == 0
131+
assert result["dim1"] == 1
132+
133+
def test_leaky_relu(self):
134+
result = extract_salient_args(
135+
"leakyrelu", "leaky_relu", (torch.randn(3),), {"negative_slope": 0.2}, []
136+
)
137+
assert result["negative_slope"] == 0.2
138+
139+
def test_embedding(self):
140+
result = extract_salient_args(
141+
"embedding",
142+
"embedding",
143+
(torch.tensor([0, 1, 2]), torch.randn(100, 64)),
144+
{},
145+
[(100, 64)],
146+
)
147+
assert result["num_embeddings"] == 100
148+
assert result["embedding_dim"] == 64
149+
150+
def test_interpolate(self):
151+
result = extract_salient_args(
152+
"interpolate",
153+
"interpolate",
154+
(torch.randn(1, 1, 4, 4),),
155+
{"scale_factor": 2.0, "mode": "bilinear"},
156+
[],
157+
)
158+
assert result["scale_factor"] == 2.0
159+
assert result["mode"] == "bilinear"
160+
161+
def test_unregistered_op_returns_empty(self):
162+
result = extract_salient_args("relu", "relu", (torch.randn(3),), {}, [])
163+
assert result == {}
164+
165+
def test_never_contains_tensors(self):
166+
"""Values must be simple Python types, never tensors."""
167+
result = extract_salient_args("softmax", "softmax", (torch.randn(1, 10),), {"dim": 1}, [])
168+
for v in result.values():
169+
assert not isinstance(v, torch.Tensor)
170+
171+
def test_sdpa(self):
172+
result = extract_salient_args(
173+
"scaleddotproductattention",
174+
"scaled_dot_product_attention",
175+
(torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16)),
176+
{"dropout_p": 0.1, "is_causal": True},
177+
[],
178+
)
179+
assert result["dropout_p"] == 0.1
180+
assert result["is_causal"] is True
181+
182+
183+
class TestBuildArgNameMap:
184+
"""Unit tests for arg name mapping helper."""
185+
186+
def test_basic_mapping(self):
187+
from torchlens import _state
188+
189+
_state._func_argnames["softmax"] = ("input", "dim", "dtype")
190+
result = _build_arg_name_map("softmax", (torch.randn(3), 1), {})
191+
assert result["dim"] == 1
192+
193+
def test_kwargs_take_precedence(self):
194+
from torchlens import _state
195+
196+
_state._func_argnames["softmax"] = ("input", "dim", "dtype")
197+
result = _build_arg_name_map("softmax", (torch.randn(3), 0), {"dim": 1})
198+
assert result["dim"] == 1
199+
200+
201+
# ---------------------------------------------------------------------------
202+
# Integration tests: end-to-end with log_forward_pass
203+
# ---------------------------------------------------------------------------
204+
205+
206+
class TestFuncConfigIntegration:
207+
"""Integration tests verifying func_config through the full pipeline."""
208+
209+
def test_conv_bn_linear_model(self):
210+
"""Conv2d, BatchNorm, Linear all populate func_config correctly."""
211+
212+
class Model(nn.Module):
213+
def __init__(self):
214+
super().__init__()
215+
self.conv = nn.Conv2d(3, 16, 3, stride=2, padding=1)
216+
self.bn = nn.BatchNorm2d(16)
217+
self.fc = nn.Linear(16, 10)
218+
219+
def forward(self, x):
220+
x = self.conv(x)
221+
x = self.bn(x)
222+
x = torch.relu(x)
223+
x = x.mean(dim=[2, 3])
224+
x = self.fc(x)
225+
return x
226+
227+
model = Model()
228+
log = tl.log_forward_pass(model, torch.randn(1, 3, 8, 8))
229+
230+
# Find layers by type
231+
conv_layer = next(ly for ly in log.layers if ly.layer_type == "conv2d")
232+
assert conv_layer.func_config["out_channels"] == 16
233+
assert conv_layer.func_config["in_channels"] == 3
234+
assert conv_layer.func_config["stride"] == (2, 2)
235+
assert conv_layer.func_config["padding"] == (1, 1)
236+
237+
bn_layer = next(ly for ly in log.layers if ly.layer_type == "batchnorm")
238+
assert "eps" in bn_layer.func_config
239+
240+
linear_layer = next(ly for ly in log.layers if ly.layer_type == "linear")
241+
assert linear_layer.func_config["out_features"] == 10
242+
assert linear_layer.func_config["in_features"] == 16
243+
244+
def test_source_tensors_have_empty_func_config(self):
245+
"""Input and buffer layers should have func_config == {}."""
246+
model = nn.BatchNorm2d(3)
247+
log = tl.log_forward_pass(model, torch.randn(1, 3, 4, 4))
248+
249+
for layer in log.layers:
250+
if layer.is_input_layer or layer.is_buffer_layer:
251+
assert layer.func_config == {}, (
252+
f"Source layer {layer.layer_label} has non-empty func_config: "
253+
f"{layer.func_config}"
254+
)
255+
256+
def test_output_nodes_have_empty_func_config(self):
257+
"""Synthetic output nodes should have func_config == {}."""
258+
model = nn.Linear(10, 5)
259+
log = tl.log_forward_pass(model, torch.randn(1, 10))
260+
261+
for layer in log.layers:
262+
if layer.is_output_layer:
263+
assert layer.func_config == {}, (
264+
f"Output layer {layer.layer_label} has non-empty func_config"
265+
)
266+
267+
def test_func_config_on_layer_pass_log(self):
268+
"""func_config should be accessible on LayerPassLog (per-pass) objects."""
269+
model = nn.Linear(10, 5)
270+
log = tl.log_forward_pass(model, torch.randn(1, 10))
271+
272+
linear_layer = next(ly for ly in log.layers if ly.layer_type == "linear")
273+
# Access via pass
274+
pass_log = linear_layer.passes[1]
275+
assert pass_log.func_config["out_features"] == 5
276+
277+
def test_func_config_in_str_output(self):
278+
"""func_config should appear in the string representation when non-empty."""
279+
model = nn.Linear(10, 5)
280+
log = tl.log_forward_pass(model, torch.randn(1, 10))
281+
282+
linear_layer = next(ly for ly in log.layers if ly.layer_type == "linear")
283+
s = str(linear_layer)
284+
assert "Config:" in s
285+
assert "out_features=5" in s
286+
287+
def test_func_config_not_in_str_when_empty(self):
288+
"""Layers with no func_config should not show the config line."""
289+
model = nn.ReLU()
290+
log = tl.log_forward_pass(model, torch.randn(1, 10))
291+
292+
relu_layer = next(ly for ly in log.layers if ly.layer_type == "relu")
293+
s = str(relu_layer)
294+
assert "Config:" not in s
295+
296+
def test_dropout_func_config(self):
297+
"""Dropout should capture the p parameter."""
298+
299+
class Model(nn.Module):
300+
def __init__(self):
301+
super().__init__()
302+
self.drop = nn.Dropout(0.3)
303+
304+
def forward(self, x):
305+
return self.drop(x)
306+
307+
model = Model()
308+
log = tl.log_forward_pass(model, torch.randn(1, 10))
309+
310+
dropout_layer = next(ly for ly in log.layers if ly.layer_type == "dropout")
311+
assert dropout_layer.func_config["p"] == 0.3
312+
313+
def test_reduction_func_config(self):
314+
"""Reduction ops should capture dim and keepdim."""
315+
316+
class Model(nn.Module):
317+
def forward(self, x):
318+
return x.sum(dim=1, keepdim=True)
319+
320+
log = tl.log_forward_pass(Model(), torch.randn(2, 3, 4))
321+
sum_layer = next(ly for ly in log.layers if ly.layer_type == "sum")
322+
assert sum_layer.func_config["dim"] == 1
323+
assert sum_layer.func_config["keepdim"] is True
324+
325+
def test_pooling_func_config(self):
326+
"""Pooling ops should capture kernel_size and stride."""
327+
328+
class Model(nn.Module):
329+
def __init__(self):
330+
super().__init__()
331+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
332+
333+
def forward(self, x):
334+
return self.pool(x)
335+
336+
log = tl.log_forward_pass(Model(), torch.randn(1, 3, 8, 8))
337+
pool_layer = next(ly for ly in log.layers if "maxpool" in ly.layer_type)
338+
assert pool_layer.func_config["kernel_size"] == 2
339+
assert pool_layer.func_config["stride"] == 2
340+
341+
def test_save_new_activations_preserves_func_config(self):
342+
"""func_config should survive save_new_activations (fast path)."""
343+
model = nn.Linear(10, 5)
344+
x1 = torch.randn(1, 10)
345+
log = tl.log_forward_pass(model, x1, layers_to_save="all")
346+
347+
# Run with new input via ModelLog method
348+
x2 = torch.randn(1, 10)
349+
log.save_new_activations(model, x2)
350+
351+
linear_layer = next(ly for ly in log.layers if ly.layer_type == "linear")
352+
assert linear_layer.func_config["out_features"] == 5
353+
assert linear_layer.func_config["in_features"] == 10
354+
355+
def test_conv_default_stride_not_shown(self):
356+
"""Conv with default stride/padding/dilation should not include them."""
357+
358+
class Model(nn.Module):
359+
def __init__(self):
360+
super().__init__()
361+
self.conv = nn.Conv2d(3, 16, 3)
362+
363+
def forward(self, x):
364+
return self.conv(x)
365+
366+
log = tl.log_forward_pass(Model(), torch.randn(1, 3, 8, 8))
367+
conv_layer = next(ly for ly in log.layers if ly.layer_type == "conv2d")
368+
assert "stride" not in conv_layer.func_config
369+
assert "padding" not in conv_layer.func_config
370+
assert "dilation" not in conv_layer.func_config
371+
assert conv_layer.func_config["kernel_size"] == (3, 3)
372+
373+
def test_all_layers_have_func_config_attribute(self):
374+
"""Every layer in the log should have a func_config attribute (dict)."""
375+
376+
class Model(nn.Module):
377+
def __init__(self):
378+
super().__init__()
379+
self.conv = nn.Conv2d(3, 16, 3)
380+
self.bn = nn.BatchNorm2d(16)
381+
self.fc = nn.Linear(16 * 6 * 6, 10)
382+
383+
def forward(self, x):
384+
x = torch.relu(self.conv(x))
385+
x = self.bn(x)
386+
x = x.view(x.size(0), -1)
387+
x = self.fc(x)
388+
return x
389+
390+
log = tl.log_forward_pass(Model(), torch.randn(1, 3, 8, 8))
391+
for layer in log.layers:
392+
assert hasattr(layer, "func_config"), f"Missing func_config on {layer.layer_label}"
393+
assert isinstance(layer.func_config, dict), (
394+
f"func_config should be dict on {layer.layer_label}, got {type(layer.func_config)}"
395+
)

0 commit comments

Comments
 (0)