We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1b50d93 commit 409b356Copy full SHA for 409b356
src/accelerate/utils/modeling.py
@@ -1041,7 +1041,8 @@ def get_balanced_memory(
1041
1042
# Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
1043
leaves = get_module_leaves(module_sizes)
1044
- module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves}
+ leaves_set = set(leaves) # Convert to set for O(1) membership testing
1045
+ module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
1046
# Once removed, leaves are the final modules.
1047
1048
mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))
0 commit comments