"""Tests for the func_config feature.

Verifies that func_config is correctly extracted for various operation types,
is empty for source/output nodes, survives the postprocessing pipeline, and
is accessible on both LayerPassLog and LayerLog.
"""

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchlens as tl
from torchlens.capture.salient_args import extract_salient_args, _build_arg_name_map


# ---------------------------------------------------------------------------
# Unit tests for extract_salient_args
# ---------------------------------------------------------------------------


class TestExtractSalientArgs:
    """Unit tests for the extractor registry."""

    def test_conv2d_basic(self):
        kernel = torch.randn(16, 3, 3, 3)
        cfg = extract_salient_args(
            "conv2d",
            "conv2d",
            (torch.randn(1, 3, 8, 8), kernel),
            {},
            [(16, 3, 3, 3)],
        )
        assert cfg["out_channels"] == 16
        assert cfg["in_channels"] == 3
        assert cfg["kernel_size"] == (3, 3)
        # Default values (stride=1, dilation=1) must be suppressed.
        assert "stride" not in cfg
        assert "dilation" not in cfg

    def test_conv2d_with_stride(self):
        cfg = extract_salient_args(
            "conv2d",
            "conv2d",
            (torch.randn(1, 3, 8, 8), torch.randn(16, 3, 3, 3)),
            {"stride": (2, 2)},
            [(16, 3, 3, 3)],
        )
        assert cfg["stride"] == (2, 2)

    def test_conv2d_with_groups(self):
        cfg = extract_salient_args(
            "conv2d",
            "conv2d",
            (torch.randn(1, 4, 8, 8), torch.randn(4, 1, 3, 3)),
            {"groups": 4},
            [(4, 1, 3, 3)],
        )
        assert cfg["groups"] == 4

    def test_linear(self):
        cfg = extract_salient_args(
            "linear",
            "linear",
            (torch.randn(1, 128), torch.randn(64, 128)),
            {},
            [(64, 128)],
        )
        assert cfg["out_features"] == 64
        assert cfg["in_features"] == 128

    def test_dropout(self):
        cfg = extract_salient_args(
            "dropout", "dropout", (torch.randn(1, 10),), {"p": 0.3}, []
        )
        assert cfg["p"] == 0.3

    def test_batch_norm(self):
        norm_kwargs = {"eps": 1e-05, "momentum": 0.1}
        cfg = extract_salient_args(
            "batchnorm", "batch_norm", (torch.randn(1, 16, 4, 4),), norm_kwargs, []
        )
        assert cfg["eps"] == 1e-05
        assert cfg["momentum"] == 0.1

    def test_layer_norm(self):
        cfg = extract_salient_args(
            "layernorm",
            "layer_norm",
            (torch.randn(1, 10),),
            {"normalized_shape": (10,)},
            [],
        )
        assert cfg["normalized_shape"] == (10,)

    def test_max_pool(self):
        cfg = extract_salient_args(
            "maxpool2d",
            "max_pool2d",
            (torch.randn(1, 3, 8, 8),),
            {"kernel_size": 2, "stride": 2},
            [],
        )
        assert cfg["kernel_size"] == 2
        assert cfg["stride"] == 2

    def test_adaptive_avg_pool(self):
        cfg = extract_salient_args(
            "adaptiveavgpool2d",
            "adaptive_avg_pool2d",
            (torch.randn(1, 3, 8, 8),),
            {"output_size": (1, 1)},
            [],
        )
        assert cfg["output_size"] == (1, 1)

    def test_softmax(self):
        cfg = extract_salient_args(
            "softmax", "softmax", (torch.randn(1, 10),), {"dim": 1}, []
        )
        assert cfg["dim"] == 1

    def test_cat(self):
        piece = torch.randn(1, 3)
        cfg = extract_salient_args("cat", "cat", ([piece, piece],), {"dim": 1}, [])
        assert cfg["dim"] == 1

    def test_reduction_mean(self):
        cfg = extract_salient_args(
            "mean",
            "mean",
            (torch.randn(2, 3, 4),),
            {"dim": (1, 2), "keepdim": True},
            [],
        )
        assert cfg["dim"] == (1, 2)
        assert cfg["keepdim"] is True

    def test_clamp(self):
        cfg = extract_salient_args(
            "clamp", "clamp", (torch.randn(3),), {"min": 0.0, "max": 1.0}, []
        )
        assert cfg["min"] == 0.0
        assert cfg["max"] == 1.0

    def test_transpose(self):
        cfg = extract_salient_args(
            "transpose", "transpose", (torch.randn(2, 3), 0, 1), {}, []
        )
        assert cfg["dim0"] == 0
        assert cfg["dim1"] == 1

    def test_leaky_relu(self):
        cfg = extract_salient_args(
            "leakyrelu", "leaky_relu", (torch.randn(3),), {"negative_slope": 0.2}, []
        )
        assert cfg["negative_slope"] == 0.2

    def test_embedding(self):
        cfg = extract_salient_args(
            "embedding",
            "embedding",
            (torch.tensor([0, 1, 2]), torch.randn(100, 64)),
            {},
            [(100, 64)],
        )
        assert cfg["num_embeddings"] == 100
        assert cfg["embedding_dim"] == 64

    def test_interpolate(self):
        cfg = extract_salient_args(
            "interpolate",
            "interpolate",
            (torch.randn(1, 1, 4, 4),),
            {"scale_factor": 2.0, "mode": "bilinear"},
            [],
        )
        assert cfg["scale_factor"] == 2.0
        assert cfg["mode"] == "bilinear"

    def test_unregistered_op_returns_empty(self):
        cfg = extract_salient_args("relu", "relu", (torch.randn(3),), {}, [])
        assert cfg == {}

    def test_never_contains_tensors(self):
        """Values must be simple Python types, never tensors."""
        cfg = extract_salient_args(
            "softmax", "softmax", (torch.randn(1, 10),), {"dim": 1}, []
        )
        assert all(not isinstance(v, torch.Tensor) for v in cfg.values())

    def test_sdpa(self):
        qkv = tuple(torch.randn(1, 4, 8, 16) for _ in range(3))
        cfg = extract_salient_args(
            "scaleddotproductattention",
            "scaled_dot_product_attention",
            qkv,
            {"dropout_p": 0.1, "is_causal": True},
            [],
        )
        assert cfg["dropout_p"] == 0.1
        assert cfg["is_causal"] is True


class TestBuildArgNameMap:
    """Unit tests for arg name mapping helper."""

    @staticmethod
    def _seed_softmax_argnames():
        # _build_arg_name_map resolves positional args through the shared
        # _state registry, so each test seeds the "softmax" entry first.
        from torchlens import _state

        _state._func_argnames["softmax"] = ("input", "dim", "dtype")

    def test_basic_mapping(self):
        self._seed_softmax_argnames()
        mapping = _build_arg_name_map("softmax", (torch.randn(3), 1), {})
        assert mapping["dim"] == 1

    def test_kwargs_take_precedence(self):
        self._seed_softmax_argnames()
        mapping = _build_arg_name_map("softmax", (torch.randn(3), 0), {"dim": 1})
        assert mapping["dim"] == 1


# ---------------------------------------------------------------------------
# Integration tests: end-to-end with log_forward_pass
# ---------------------------------------------------------------------------


class TestFuncConfigIntegration:
    """Integration tests verifying func_config through the full pipeline."""

    @staticmethod
    def _layer_of_type(log, layer_type):
        # Shared lookup used by most tests below; raises StopIteration when
        # absent, exactly like the inline ``next(...)`` it replaces.
        return next(ly for ly in log.layers if ly.layer_type == layer_type)

    def test_conv_bn_linear_model(self):
        """Conv2d, BatchNorm, Linear all populate func_config correctly."""

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 16, 3, stride=2, padding=1)
                self.bn = nn.BatchNorm2d(16)
                self.fc = nn.Linear(16, 10)

            def forward(self, x):
                out = self.conv(x)
                out = self.bn(out)
                out = torch.relu(out)
                out = out.mean(dim=[2, 3])
                return self.fc(out)

        log = tl.log_forward_pass(Net(), torch.randn(1, 3, 8, 8))

        conv = self._layer_of_type(log, "conv2d")
        assert conv.func_config["out_channels"] == 16
        assert conv.func_config["in_channels"] == 3
        assert conv.func_config["stride"] == (2, 2)
        assert conv.func_config["padding"] == (1, 1)

        assert "eps" in self._layer_of_type(log, "batchnorm").func_config

        fc = self._layer_of_type(log, "linear")
        assert fc.func_config["out_features"] == 10
        assert fc.func_config["in_features"] == 16

    def test_source_tensors_have_empty_func_config(self):
        """Input and buffer layers should have func_config == {}."""
        log = tl.log_forward_pass(nn.BatchNorm2d(3), torch.randn(1, 3, 4, 4))

        sources = [ly for ly in log.layers if ly.is_input_layer or ly.is_buffer_layer]
        for layer in sources:
            assert layer.func_config == {}, (
                f"Source layer {layer.layer_label} has non-empty func_config: "
                f"{layer.func_config}"
            )

    def test_output_nodes_have_empty_func_config(self):
        """Synthetic output nodes should have func_config == {}."""
        log = tl.log_forward_pass(nn.Linear(10, 5), torch.randn(1, 10))

        for layer in log.layers:
            if not layer.is_output_layer:
                continue
            assert layer.func_config == {}, (
                f"Output layer {layer.layer_label} has non-empty func_config"
            )

    def test_func_config_on_layer_pass_log(self):
        """func_config should be accessible on LayerPassLog (per-pass) objects."""
        log = tl.log_forward_pass(nn.Linear(10, 5), torch.randn(1, 10))

        # Reach the same per-pass object via the layer's passes mapping.
        first_pass = self._layer_of_type(log, "linear").passes[1]
        assert first_pass.func_config["out_features"] == 5

    def test_func_config_in_str_output(self):
        """func_config should appear in the string representation when non-empty."""
        log = tl.log_forward_pass(nn.Linear(10, 5), torch.randn(1, 10))

        rendered = str(self._layer_of_type(log, "linear"))
        assert "Config:" in rendered
        assert "out_features=5" in rendered

    def test_func_config_not_in_str_when_empty(self):
        """Layers with no func_config should not show the config line."""
        log = tl.log_forward_pass(nn.ReLU(), torch.randn(1, 10))
        assert "Config:" not in str(self._layer_of_type(log, "relu"))

    def test_dropout_func_config(self):
        """Dropout should capture the p parameter."""

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.drop = nn.Dropout(0.3)

            def forward(self, x):
                return self.drop(x)

        log = tl.log_forward_pass(Net(), torch.randn(1, 10))
        assert self._layer_of_type(log, "dropout").func_config["p"] == 0.3

    def test_reduction_func_config(self):
        """Reduction ops should capture dim and keepdim."""

        class Net(nn.Module):
            def forward(self, x):
                return x.sum(dim=1, keepdim=True)

        log = tl.log_forward_pass(Net(), torch.randn(2, 3, 4))
        summed = self._layer_of_type(log, "sum")
        assert summed.func_config["dim"] == 1
        assert summed.func_config["keepdim"] is True

    def test_pooling_func_config(self):
        """Pooling ops should capture kernel_size and stride."""

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

            def forward(self, x):
                return self.pool(x)

        log = tl.log_forward_pass(Net(), torch.randn(1, 3, 8, 8))
        # Substring match: the exact pool layer_type label may carry a suffix.
        pool = next(ly for ly in log.layers if "maxpool" in ly.layer_type)
        assert pool.func_config["kernel_size"] == 2
        assert pool.func_config["stride"] == 2

    def test_save_new_activations_preserves_func_config(self):
        """func_config should survive save_new_activations (fast path)."""
        model = nn.Linear(10, 5)
        log = tl.log_forward_pass(model, torch.randn(1, 10), layers_to_save="all")

        # Re-run with a fresh input through the ModelLog fast path.
        log.save_new_activations(model, torch.randn(1, 10))

        fc = self._layer_of_type(log, "linear")
        assert fc.func_config["out_features"] == 5
        assert fc.func_config["in_features"] == 10

    def test_conv_default_stride_not_shown(self):
        """Conv with default stride/padding/dilation should not include them."""

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 16, 3)

            def forward(self, x):
                return self.conv(x)

        log = tl.log_forward_pass(Net(), torch.randn(1, 3, 8, 8))
        cfg = self._layer_of_type(log, "conv2d").func_config
        for suppressed in ("stride", "padding", "dilation"):
            assert suppressed not in cfg
        assert cfg["kernel_size"] == (3, 3)

    def test_all_layers_have_func_config_attribute(self):
        """Every layer in the log should have a func_config attribute (dict)."""

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 16, 3)
                self.bn = nn.BatchNorm2d(16)
                self.fc = nn.Linear(16 * 6 * 6, 10)

            def forward(self, x):
                x = torch.relu(self.conv(x))
                x = self.bn(x)
                x = x.view(x.size(0), -1)
                return self.fc(x)

        log = tl.log_forward_pass(Net(), torch.randn(1, 3, 8, 8))
        for layer in log.layers:
            assert hasattr(layer, "func_config"), f"Missing func_config on {layer.layer_label}"
            assert isinstance(layer.func_config, dict), (
                f"func_config should be dict on {layer.layer_label}, got {type(layer.func_config)}"
            )