Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion graph_net/torch/sym_dim_reifiers/naive_cv_sym_dim_reifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ def match(self) -> bool:
if self.dyn_dim_cstrs is None:
return False
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
return sym_shapes_str in self._get_map_cv_sym_shapes_str2reifier()
if sym_shapes_str == "[]":
return False
elif sym_shapes_str not in self._get_map_cv_sym_shapes_str2reifier():
print(
f"[CV SymDim Reify] No reifier matched symbolic shapes:{sym_shapes_str} \nPlease add a reify rule to _get_map_cv_sym_shapes_str2reifier()"
)
return False
return True

def reify(self):
assert self.match()
Expand All @@ -40,6 +47,15 @@ def _get_map_cv_sym_shapes_str2reifier(cls):
"[(S0,80,S1),(S0,)]": cls.reify_nemo_asr_s0_s1,
"[(S0,3,512,1024)]": cls.reify_semantic_seg_s0,
"[(S0,3,640,640)]": cls.reify_yolo_s0,
"[(S0,)]": cls.reify_s0,
"[(S0,S1,S2)]": cls.reify_s0_s1_s2,
"[(S0,S1,S2,S2)]": cls.reify_s0_s1_s2,
"[(S0,3,S1,S2,S2)]": cls.reify_s0_s1_s2,
"[(S0,S1,S1,384)]": cls.reify_s0_s1_384,
"[(S0,S1),(S0,S1,2560)]": cls.reify_s0_s1_2560,
"[(S0,S1,256),(S0,S1,256)]": cls.reify_s0_s1_256,
"[(S0,S1),(S0,S1,3072)]": cls.reify_s0_s1_128_1024,
"[(S0,S1),(S0,S1,1024),(S0,S1,1024)]": cls.reify_s0_s1_128_1024,
}
return cls.g_cv_sym_shapes_str2reifier

Expand Down Expand Up @@ -184,3 +200,89 @@ def reify_yolo_s0(self):
return {
S0: [[1], [2], [4], [8], [12], [16], [24], [32], [64]],
}

def reify_s0(self):
S0 = (sympy.Symbol("S0"),)
return {
S0: [[1], [2], [4], [8], [16], [32], [48], [64], [128]],
}

def reify_s0_s1_s2(self):
S0S1S2 = (sympy.Symbol("S0"), sympy.Symbol("S1"), sympy.Symbol("S2"))
return {
S0S1S2: [
[1, 8, 8],
[1, 8, 16],
[1, 16, 16],
[2, 16, 16],
[1, 16, 32],
[2, 16, 32],
[4, 16, 32],
[1, 32, 32],
[2, 32, 32],
],
}

def reify_s0_s1_384(self):
S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1"))
return {
S0S1: [
[1, 64],
[1, 128],
[1, 192],
[1, 224],
[1, 256],
[4, 224],
[8, 224],
[32, 224],
[64, 224],
],
}

def reify_s0_s1_2560(self):
S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1"))
return {
S0S1: [
[1, 16],
[1, 32],
[2, 32],
[1, 64],
[2, 64],
[1, 128],
[2, 128],
[1, 256],
[2, 256],
],
}

def reify_s0_s1_256(self):
S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1"))
return {
S0S1: [
[1, 8],
[1, 16],
[2, 16],
[1, 32],
[2, 32],
[4, 32],
[1, 64],
[2, 64],
[4, 64],
],
}

def reify_s0_s1_128_1024(self):
S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1"))
return {
S0S1: [
[1, 8],
[1, 16],
[2, 16],
[1, 32],
[2, 32],
[4, 32],
[1, 64],
[2, 64],
[4, 64],
],
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ def match(self) -> bool:
if self.dyn_dim_cstrs is None:
return False
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier()
if sym_shapes_str == "[]":
return False
elif sym_shapes_str not in self._get_map_nlp_sym_shapes_str2reifier():
print(
f"[NLP SymDim Reify] No reifier matched symbolic shapes:{sym_shapes_str} \nPlease add a reify rule to _get_map_nlp_sym_shapes_str2reifier()"
)
return False
return True

def reify(self):
assert self.match()
Expand Down
3 changes: 2 additions & 1 deletion samples/timm/sequencer2d_l.in1k/graph_net.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"num_nodes_required": 1,
"dynamic": true,
"source": "timm",
"heuristic_tag": "computer_vision"
"heuristic_tag": "computer_vision",
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
11 changes: 10 additions & 1 deletion samples/torchgeometric/MetaPath2Vec/graph_net.json
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
{"framework": "torch", "num_devices_required": 1, "num_nodes_required": 1, "dynamic": false, "source": "torchgeometric", "heuristic_tag": "other", "dimension_generalization_passes": []}
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": false,
"source": "torchgeometric",
"heuristic_tag": "other",
"dimension_generalization_passes": [],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
11 changes: 10 additions & 1 deletion samples/torchgeometric/Node2Vec/graph_net.json
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
{"framework": "torch", "num_devices_required": 1, "num_nodes_required": 1, "dynamic": false, "source": "torchgeometric", "heuristic_tag": "other", "dimension_generalization_passes": []}
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"dynamic": false,
"source": "torchgeometric",
"heuristic_tag": "other",
"dimension_generalization_passes": [],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
10 changes: 9 additions & 1 deletion samples/torchvision/googlenet/graph_net.json
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
{"framework": "torch", "num_devices_required": 1, "num_nodes_required": 1, "source": "torchvision", "heuristic_tag": "computer_vision", "dimension_generalization_passes": []}
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"source": "torchvision",
"heuristic_tag": "computer_vision",
"dimension_generalization_passes": [],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
10 changes: 9 additions & 1 deletion samples/torchvision/inception_v3/graph_net.json
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
{"framework": "torch", "num_devices_required": 1, "num_nodes_required": 1, "source": "torchvision", "heuristic_tag": "computer_vision", "dimension_generalization_passes": []}
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"source": "torchvision",
"heuristic_tag": "computer_vision",
"dimension_generalization_passes": [],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
10 changes: 9 additions & 1 deletion samples/torchvision/r3d_18/graph_net.json
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
{"framework": "torch", "num_devices_required": 1, "num_nodes_required": 1, "source": "torchvision", "heuristic_tag": "computer_vision", "dimension_generalization_passes": []}
{
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1,
"source": "torchvision",
"heuristic_tag": "computer_vision",
"dimension_generalization_passes": [],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"license:apache-2.0",
"region:us"
],
"heuristic_tag": "other"
"heuristic_tag": "other",
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
3 changes: 2 additions & 1 deletion samples/transformers-auto-model/UAE-Large-V1/graph_net.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
"endpoints_compatible",
"region:us"
],
"heuristic_tag": "other"
"heuristic_tag": "other",
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@
"non_batch_call_function_full_plus_one_pass",
"non_batch_call_function_zeros_pass",
"non_batch_call_function_arange_plus_one_pass"
]
],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@
"non_batch_call_function_full_plus_one_pass",
"non_batch_call_function_zeros_pass",
"non_batch_call_function_arange_plus_one_pass"
]
],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
"non_batch_call_function_full_plus_one_pass",
"non_batch_call_function_zeros_pass",
"non_batch_call_function_arange_plus_one_pass"
]
],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@
"non_batch_call_function_full_plus_one_pass",
"non_batch_call_function_zeros_pass",
"non_batch_call_function_arange_plus_one_pass"
]
],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@
"non_batch_call_function_full_plus_one_pass",
"non_batch_call_function_zeros_pass",
"non_batch_call_function_arange_plus_one_pass"
]
],
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"endpoints_compatible",
"region:us"
],
"heuristic_tag": "other"
"heuristic_tag": "other",
"symbolic_dimension_reifier": "naive_cv_sym_dim_reifier"
}