Skip to content

Commit ac16dd7

Browse files
commit suggestions
1 parent 839622c commit ac16dd7

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

evaluation/generation/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def get_args():
1212
parser.add_argument("--generate-max-length", type=int, default=50, help="max generation length")
1313
parser.add_argument("--greedy", action="store_true")
1414
parser.add_argument("--top-k", type=int, default=0)
15+
parser.add_argument("--offload_folder", type=str, help="offload folder for accelerate", default="./offload")
1516

1617
return parser.parse_args()
1718

@@ -40,6 +41,7 @@ def main():
4041
device_map="auto" if args.parallelize else None,
4142
torch_dtype=torch.bfloat16,
4243
revision="gs{}".format(args.global_step) if args.global_step else None
44+
offload_floder=args.offload_folder,
4345
)
4446
print(f"Loaded model in {datetime.datetime.now() - start}")
4547

0 commit comments

Comments
 (0)