|
62 | 62 | from .source_transformation.custom_kv_cache import ( |
63 | 63 | replace_kv_cache_with_custom_kv_cache, |
64 | 64 | replace_kv_cache_with_quantized_kv_cache, |
| 65 | + replace_kv_cache_with_ring_kv_cache, |
65 | 66 | ) |
66 | 67 |
|
67 | 68 | from .source_transformation.quantize import ( |
@@ -147,6 +148,23 @@ def build_model( |
147 | 148 | return export_llama(args) |
148 | 149 |
|
149 | 150 |
|
def parse_list_of_ints(s: str) -> list:
    """Parse a command-line string into a list of integers.

    Written to be used as an ``argparse`` ``type=`` callable for values such
    as ``"[0, 16, 0, 16]"``.

    Args:
        s: Text of a Python list literal containing only integers.

    Returns:
        The parsed list of ``int`` values (may be empty).

    Raises:
        argparse.ArgumentTypeError: If ``s`` is not a valid Python literal,
            is not a list, or contains a non-integer element (booleans are
            rejected as well).
    """
    import ast

    error_message = "Must be a list of integers, e.g., [0, 16, 0, 16]"

    # Keep the try body minimal: only literal_eval can fail here, and these
    # are the exceptions it is documented to raise on malformed input.
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise argparse.ArgumentTypeError(error_message) from err

    # bool is a subclass of int, so it has to be excluded explicitly.
    if not isinstance(parsed, list) or not all(
        isinstance(item, int) and not isinstance(item, bool) for item in parsed
    ):
        raise argparse.ArgumentTypeError(error_message)

    return parsed
| 166 | + |
| 167 | + |
150 | 168 | def build_args_parser() -> argparse.ArgumentParser: |
151 | 169 | parser = argparse.ArgumentParser() |
152 | 170 | parser.add_argument("-o", "--output-dir", default=".", help="output directory") |
@@ -357,6 +375,13 @@ def build_args_parser() -> argparse.ArgumentParser: |
357 | 375 | help="maximum length of context for model to remember", |
358 | 376 | ) |
359 | 377 |
|
| 378 | + parser.add_argument( |
| 379 | + "--local_global_attention", |
| 380 | + type=parse_list_of_ints, |
| 381 | + default=None, |
| 382 | + help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].", |
| 383 | + ) |
| 384 | + |
360 | 385 | parser.add_argument("-2", "--fairseq2", action="store_true") |
361 | 386 | parser.add_argument("-v", "--verbose", action="store_true") |
362 | 387 | parser.add_argument( |
@@ -1297,6 +1322,14 @@ def _get_source_transforms( # noqa |
1297 | 1322 | if args.vulkan: |
1298 | 1323 | transforms.append(replace_with_vulkan_rotary_emb) |
1299 | 1324 |
|
| 1325 | + if args.local_global_attention: |
| 1326 | + transforms.append( |
| 1327 | + partial( |
| 1328 | + replace_kv_cache_with_ring_kv_cache, |
| 1329 | + layer_sizes=args.local_global_attention, |
| 1330 | + ) |
| 1331 | + ) |
| 1332 | + |
1300 | 1333 | return transforms |
1301 | 1334 |
|
1302 | 1335 |
|
|
0 commit comments