File tree Expand file tree Collapse file tree 9 files changed +8
-8
lines changed Expand file tree Collapse file tree 9 files changed +8
-8
lines changed Original file line number Diff line number Diff line change 44
44
from torchrec .distributed .comm import get_local_size
45
45
from torchrec .distributed .embedding_types import EmbeddingComputeKernel
46
46
from torchrec .distributed .planner import Topology
47
+ from torchrec .distributed .test_utils .model_input import ModelInput
47
48
48
49
from torchrec .distributed .test_utils .multi_process import (
49
50
MultiProcessContext ,
50
51
run_multi_process_func ,
51
52
)
53
+ from torchrec .distributed .test_utils .pipeline_config import PipelineConfig
52
54
from torchrec .distributed .test_utils .table_config import EmbeddingTablesConfig
53
- from torchrec .distributed .test_utils .test_input import ModelInput
54
55
from torchrec .distributed .test_utils .test_model import TestOverArchLarge
55
- from torchrec .distributed .test_utils .train_pipeline import PipelineConfig
56
56
from torchrec .distributed .train_pipeline import TrainPipeline
57
57
from torchrec .distributed .types import ShardingType
58
58
from torchrec .modules .embedding_configs import EmbeddingBagConfig
Original file line number Diff line number Diff line change 32
32
from torchrec .distributed .planner .constants import NUM_POOLINGS , POOLING_FACTOR
33
33
from torchrec .distributed .planner .planners import HeteroEmbeddingShardingPlanner
34
34
from torchrec .distributed .planner .types import ParameterConstraints
35
- from torchrec .distributed .test_utils .test_input import ModelInput
35
+ from torchrec .distributed .test_utils .model_input import ModelInput
36
36
from torchrec .distributed .test_utils .test_model import (
37
37
TestEBCSharder ,
38
38
TestSparseNN ,
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change 40
40
get_module_to_default_sharders ,
41
41
table_wise ,
42
42
)
43
+ from torchrec .distributed .test_utils .model_input import ModelInput
43
44
44
45
from torchrec .distributed .test_utils .multi_process import (
45
46
MultiProcessContext ,
46
47
MultiProcessTestBase ,
47
48
)
48
- from torchrec .distributed .test_utils .test_input import ModelInput
49
49
from torchrec .distributed .test_utils .test_model_parallel import ModelParallelTestShared
50
50
from torchrec .distributed .test_utils .test_sharding import (
51
51
copy_state_dict ,
Original file line number Diff line number Diff line change 11
11
12
12
import torch
13
13
from torch import nn
14
- from torchrec .distributed .test_utils .test_input import ModelInput
14
+ from torchrec .distributed .test_utils .model_input import ModelInput
15
15
from torchrec .modules .deepfm import DeepFM , FactorizationMachine
16
16
from torchrec .modules .embedding_modules import EmbeddingBagCollection
17
17
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor , KeyedTensor
Original file line number Diff line number Diff line change 12
12
import torch
13
13
from torch import nn
14
14
from torchrec .datasets .utils import Batch
15
- from torchrec .distributed .test_utils .test_input import ModelInput
15
+ from torchrec .distributed .test_utils .model_input import ModelInput
16
16
from torchrec .modules .crossnet import LowRankCrossNet
17
17
from torchrec .modules .embedding_modules import EmbeddingBagCollection
18
18
from torchrec .modules .mlp import MLP
Original file line number Diff line number Diff line change 14
14
import torch
15
15
from parameterized import parameterized
16
16
from torch .testing import FileCheck # @manual
17
- from torchrec .distributed .test_utils .test_input import ModelInput
17
+ from torchrec .distributed .test_utils .model_input import ModelInput
18
18
from torchrec .fx import symbolic_trace , Tracer
19
19
from torchrec .models .deepfm import (
20
20
DenseArch ,
Original file line number Diff line number Diff line change 16
16
from torch import nn
17
17
from torch .testing import FileCheck # @manual
18
18
from torchrec .datasets .utils import Batch
19
- from torchrec .distributed .test_utils .test_input import ModelInput
19
+ from torchrec .distributed .test_utils .model_input import ModelInput
20
20
from torchrec .fx import symbolic_trace
21
21
from torchrec .ir .serializer import JsonSerializer
22
22
from torchrec .ir .utils import decapsulate_ir_modules , encapsulate_ir_modules
You can’t perform that action at this time.
0 commit comments