-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy pathtest_schema_utils.py
More file actions
70 lines (52 loc) · 2.06 KB
/
test_schema_utils.py
File metadata and controls
70 lines (52 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pytest
import torch
from torch import nn
from merlin.models.torch.utils.module_utils import module_test
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
from merlin.schema import Schema, Tags
class TrackedModule(SchemaTrackingMixin, nn.Module):
    """Minimal schema-tracked module whose forward returns a single tensor.

    It wraps one lazily-initialised linear projection to 10 output features.
    ``in_features`` is inferred by ``LazyLinear`` on the first forward call.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(out_features=10)

    def forward(self, x: torch.Tensor):
        projected = self.linear(x)
        return projected
class TrackedDictModule(SchemaTrackingMixin, nn.Module):
    """Schema-tracked module whose forward returns a dict of two tensors.

    Keys ``"a"`` and ``"b"`` each hold the result of a separate call to the
    same lazily-initialised linear layer (10 output features).
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(out_features=10)

    def forward(self, x: torch.Tensor):
        # Two independent calls, evaluated in key order ("a" then "b").
        first = self.linear(x)
        second = self.linear(x)
        return {"a": first, "b": second}
class TestSchemaTrackingMixin:
    """Tests for ``SchemaTrackingMixin``.

    Covers output-schema capture after ``forward`` and the reset of the
    tracking flag when the module switches train/eval mode.
    """

    def test_tensor(self):
        inputs = torch.randn(1, 5)
        tracked_module = TrackedModule()
        # Eager call first (this also materialises the LazyLinear weights)
        # so the value returned by ``module_test`` can be checked against it,
        # mirroring ``test_dict``.
        outputs = tracked_module(inputs)
        traced_outputs = module_test(tracked_module, inputs)
        assert torch.equal(outputs, traced_outputs)

        schema = tracked_module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 1
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1

    def test_dict(self):
        inputs = torch.randn(1, 5)
        tracked_module = TrackedDictModule()
        outputs = tracked_module(inputs)
        traced_outputs = module_test(tracked_module, inputs)
        assert torch.equal(outputs["a"], traced_outputs["a"])
        assert torch.equal(outputs["b"], traced_outputs["b"])

        schema = tracked_module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 2
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2

    def test_exception(self):
        # Requesting the output schema before any forward pass must fail.
        tracked_module = TrackedModule()
        with pytest.raises(RuntimeError):
            tracked_module.output_schema()

    def test_train(self):
        tracked_module = TrackedModule()
        tracked_module(torch.randn(1, 5))
        # Precondition: without it, the reset check below would pass
        # vacuously if forward never set the flag.
        assert tracked_module._forward_called
        tracked_module.train()
        assert not tracked_module._forward_called

    def test_eval(self):
        tracked_module = TrackedModule()
        tracked_module(torch.randn(1, 5))
        # Precondition: without it, the reset check below would pass
        # vacuously if forward never set the flag.
        assert tracked_module._forward_called
        tracked_module.eval()
        assert not tracked_module._forward_called