Skip to content

Commit 3254880

Browse files
author
liord
committed
Fix bug that caused ViT models to raise errors.
1 parent aa2a182 commit 3254880

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
def compute_graph_max_cut(memory_graph: MemoryGraph,
2929
n_iter: int = 50,
30-
astar_n_iter: int = 500,
30+
astar_n_iter: int = 1000,
3131
eps: float = 1e-2) -> Tuple[List[BaseNode], float, List[Cut]]:
3232
"""
3333
A wrapper function to compute max cut and schedule for a given model.

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transpose
103103
matmul_name = f'{attention_node_name}_matmul1'
104104
return FunctionalNode(name=matmul_name,
105105
framework_attr={},
106-
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
106+
input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape[0])),
107107
output_shape=tuple(matmul1_output_shape),
108108
weights={},
109109
layer_class=torch.matmul,

0 commit comments

Comments
 (0)