Commit 0608655

Author: liord (committed)

Fix a bug in list indexing that caused ViT models to raise errors. In addition, increase the number of A* iterations in compute_graph_max_cut from 500 to 1000 to ensure ViT models function correctly.
1 parent: aa2a182

File tree

2 files changed: +2 -2 lines changed


model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 def compute_graph_max_cut(memory_graph: MemoryGraph,
                           n_iter: int = 50,
-                          astar_n_iter: int = 500,
+                          astar_n_iter: int = 1000,
                           eps: float = 1e-2) -> Tuple[List[BaseNode], float, List[Cut]]:
     """
     A wrapper function to compute max cut and schedule for a given model.
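The astar_n_iter change raises an iteration budget. As an illustration only (this is not the Model Compression Toolkit implementation), the sketch below shows a generic iteration-bounded A* search and why the budget matters: if the search has not reached its goal within the budget, it gives up, so larger graphs such as those produced by ViT models can need a higher limit. The graph, heuristic, and function names here are hypothetical.

```python
import heapq

def astar(start, goal, neighbors, h, n_iter=1000):
    """Return the cost to reach goal, or None if the iteration budget runs out."""
    open_heap = [(h(start), 0, start)]  # entries are (f = g + h, g, node)
    best_g = {start: 0}
    for _ in range(n_iter):
        if not open_heap:
            return None                 # search space exhausted, no path
        f, g, node = heapq.heappop(open_heap)
        if node == goal:
            return g
        for nxt, cost in neighbors(node):
            ng = g + cost
            if ng < best_g.get(nxt, float('inf')):
                best_g[nxt] = ng
                heapq.heappush(open_heap, (ng + h(nxt), ng, nxt))
    return None                         # budget exhausted before reaching the goal

# A toy line graph 0 -> 1 -> ... -> 20 with unit edge costs and an
# admissible heuristic: a small budget fails, a larger one succeeds.
neighbors = lambda n: [(n + 1, 1)] if n < 20 else []
h = lambda n: 20 - n
```

With this toy graph, astar(0, 20, neighbors, h, n_iter=10) exhausts its budget and returns None, while a larger budget finds the path of cost 20; the commit's 500 to 1000 bump plays the same role for the max-cut scheduler's search.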

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transpose
         matmul_name = f'{attention_node_name}_matmul1'
         return FunctionalNode(name=matmul_name,
                               framework_attr={},
-                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
+                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape[0])),
                               output_shape=tuple(matmul1_output_shape),
                               weights={},
                               layer_class=torch.matmul,
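The indexing fix above matters because a node's output_shape holds a list of shapes (one per output tensor), so tuple() over the outer list does not produce a shape tuple. A minimal sketch, assuming a single-output node whose shape values are illustrative:

```python
# Hypothetical example of the list-indexing bug fixed in this commit.
# output_shape is a list of shapes, one entry per output tensor, e.g. a
# single-output transposed-K node in a ViT attention block might carry:
transposed_k_output_shape = [[8, 197, 64]]  # one output, shape [8, 197, 64]

# Before the fix: tuple() over the outer list yields a one-element tuple
# containing a list, not a flat shape tuple, so downstream shape handling
# for the matmul input breaks.
buggy = tuple(transposed_k_output_shape)     # ([8, 197, 64],)

# After the fix: index the first output's shape before converting.
fixed = tuple(transposed_k_output_shape[0])  # (8, 197, 64)
```

Note that q_node.output_shape[0] on the preceding line already used this pattern; the fix makes the K-node argument consistent with it.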
