import argparse
import time
import os
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--name", type=str, help="model name or path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.)

    return parser.parse_args()

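# example invocation (illustrative; substitute the actual script filename):
#   python <script>.py --name bigscience/bloom --batch_size 1 --benchmark
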
def get_max_memory_per_gpu_dict(dtype, model_name):
    """ try to generate the memory map based on what we know about the model and the available hardware """

    # figure out the memory map - the minimum per gpu required to load the model
    n_gpus = torch.cuda.device_count()

    if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
        # hand-crafted optimized memory map for an 8x80GB setup over BLOOM
        # this works with bs=40
        return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}

    try:
        # model_params calculation, as we don't have a model yet to do:
        # model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

        config = AutoConfig.from_pretrained(model_name)
        # most configs expose the embedding width as `hidden_size`; fall back to the
        # older `n_embed` attribute when it is not present
        h = getattr(config, "hidden_size", None) or config.n_embed
        l = config.n_layer
        v = config.vocab_size
        # from https://github.com/bigscience-workshop/bigscience/tree/6917a3b5fefcf439d3485ca184b4d9f6ab605150/math#model-sizing
        model_params = l*(12*h**2 + 13*h) + v*h + 4*h
    except Exception:
        print(f"The model {model_name} has a broken config file. Please notify the owner")
        raise

    bytes_per_param = torch.finfo(dtype).bits / 8
    param_memory_total_in_bytes = model_params * bytes_per_param
    # add 5% since weight sizes aren't the same and some GPU may need more memory
    param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
    print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

    # check the real available memory
    # load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
    torch.ones(1).cuda()
    max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
    if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
        raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

    return {i: param_memory_per_gpu_in_bytes for i in range(n_gpus)}

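### Setup

# t_start is taken before the tokenizer and model are loaded, so the
# "Start to ready to generate" stat reported at the end covers loading time as well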
t_start = time.time()

num_tokens = 100

args = get_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

rank = local_rank

model_name = args.name
if rank == 0:
    print(f"Loading model {model_name}")


tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16

#print(get_max_memory_per_gpu_dict())


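# device_map="auto" together with max_memory lets accelerate's big-model inference
# shard the weights across the available GPUs, following the per-gpu budget computed above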
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory=get_max_memory_per_gpu_dict(dtype, model_name),
    torch_dtype=dtype,
)


if args.benchmark:
    t_ready = time.time()



### Generate

if rank == 0:
    print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    "DeepSpeed is a machine learning framework",
    "He is working on",
    "He has a",
    "He got all",
    "Everyone is happy and I can",
    "The new movie that got Oscar this year",
    "In the far far distance from our galaxy,",
    "Peace is the only way"
]

if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

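# greedy decoding by default; the commented-out variants below disable the kv-cache
# or pin the total sequence length, which can be useful for debugging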
generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
#generate_kwargs = dict(max_new_tokens=num_tokens, use_cache=False, do_sample=False)
#generate_kwargs = dict(min_length=num_tokens, max_length=num_tokens, do_sample=False)

if rank == 0:
    print(f"Generate args {generate_kwargs}")
inputs = input_sentences[:args.batch_size]
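
# note: the whole padded batch is moved to the first GPU; accelerate's dispatch hooks
# then move activations between the GPUs that hold the different model shards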
def generate():
    """ returns a zip of inputs, outputs and number of new tokens """

    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
if rank == 0:
    for i,o,_ in generated:
        print(f"{'-'*60}\nin={i}\nout={o}\n")


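# drop cached CUDA blocks and run the garbage collector so the benchmark below
# starts from a cleaner memory state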
if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()

### Benchmark

if args.benchmark:
    if rank == 0:
        print("*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
    torch.cuda.synchronize()
    if rank == 0:
        throughput = (time.time() - t0)/(total_new_tokens_generated)
        print(f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
""")