Skip to content

Commit de3bf20

Browse files
committed
Add skip_sparsity_compression_stats option
1 parent 48e67a1 commit de3bf20

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/automation/tasks/llmcompressor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __init__(
3232
vision_samples: Optional[int]=None,
3333
max_seq_len: int=8192,
3434
trust_remote_code: bool=False,
35-
max_memory_per_gpu: str="hessian",
35+
max_memory_per_gpu: str="auto",
36+
skip_sparsity_compression_stats=True,
3637
tracing_class: Optional[str]=None,
3738
tags: Union[str, List[str]]=None,
3839
task_type: str="training",
@@ -94,6 +95,7 @@ def __init__(
9495
self.model_class = model_class
9596
self.dataset_loader = dataset_loader
9697
self.data_collator = data_collator
98+
self.skip_sparsity_compression_stats = skip_sparsity_compression_stats
9799

98100
if tags is not None:
99101
tags = list(set(config_kwargs.pop("tags", []).extend(tags)))
@@ -140,6 +142,7 @@ def get_arguments(self):
140142
"trust_remote_code": self.trust_remote_code,
141143
"max_memory_per_gpu": self.max_memory_per_gpu,
142144
"tracing_class": self.tracing_class,
145+
"skip_sparsity_compression_stats": self.skip_sparsity_compression_stats,
143146
"tags": self.tags,
144147
},
145148
}

src/automation/tasks/scripts/llmcompressor_script.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def llmcompressor_main(
2727
max_seq_len,
2828
text_samples,
2929
vision_samples,
30+
skip_sparsity_compression_stats,
3031
save_directory,
3132
data_collator,
3233
):
@@ -129,7 +130,7 @@ def llmcompressor_main(
129130
)
130131

131132
# Save model compressed
132-
model.save_pretrained(save_directory, save_compressed=True)
133+
model.save_pretrained(save_directory, save_compressed=True, skip_sparsity_compression_stats=skip_sparsity_compression_stats)
133134
processor.save_pretrained(save_directory)
134135

135136
return recipe
@@ -155,6 +156,7 @@ def main(configurations=None):
155156
recipe = args.get("recipe", None)
156157
recipe_args = args.get("recipe_args", None)
157158
tags = args.get("tags", None)
159+
skip_sparsity_compression_stats = parse_argument(args["skip_sparsity_compression_stats"], bool)
158160

159161
dataset_loader_fn = load_callable_configuration("dataset loader", configurations)
160162
data_collator_fn = load_callable_configuration("data collator", configurations)
@@ -175,6 +177,7 @@ def main(configurations=None):
175177
max_seq_len,
176178
text_samples,
177179
vision_samples,
180+
skip_sparsity_compression_stats,
178181
save_directory,
179182
data_collator_fn,
180183
)

0 commit comments

Comments
 (0)