From 5b78df61e136de4c8c4ac03baf937f3159240a01 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 12 Mar 2025 22:28:53 -0700 Subject: [PATCH 1/7] unclear why this won't compile Summary: when not in prefill, it doesn't compile, but when i pad so more than one token, still doesn't compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/README.md | 2 +- mixtral-moe/generate.py | 16 +++- mixtral-moe/model.py | 152 ++++++++++++++++++++++++++++++- mixtral-moe/run.sh | 4 + scripts/convert_hf_checkpoint.py | 5 +- 5 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 mixtral-moe/run.sh diff --git a/mixtral-moe/README.md b/mixtral-moe/README.md index cf5e9d9b..bfe1c2b1 100644 --- a/mixtral-moe/README.md +++ b/mixtral-moe/README.md @@ -4,7 +4,7 @@ ## Downloading Weights ```bash -export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1 +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 python scripts/download.py --repo_id $MODEL_REPO python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO ``` diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 9aa076b6..7e0d254f 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -74,6 +74,8 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) + next_token, next_prob = next_token.clone(), next_prob.clone() + input_pos += 1 new_tokens.append(next_token.clone()) callback(new_tokens[-1]) @@ -117,7 +119,6 @@ def generate( empty[:T] = prompt seq = empty input_pos = torch.arange(0, T, device=device) - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) seq[T] = next_token @@ -144,8 +145,12 @@ def _load_model(checkpoint_path, device, precision, use_tp): simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8) model = simple_quantizer.convert_for_runtime() - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - model.load_state_dict(checkpoint, assign=True) + try: + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + except: + model = Transformer.from_name(checkpoint_path.parent.name) + if use_tp: from tp import apply_tp @@ -172,7 +177,7 @@ def main( ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer. """ - assert checkpoint_path.is_file(), checkpoint_path + # assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) @@ -289,7 +294,8 @@ def callback(x): import argparse parser = argparse.ArgumentParser(description='Your CLI description.') - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + # parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--prompt', type=str, default="H", help='Input prompt.') parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index 9249ac9d..aabcd435 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -52,7 +52,7 @@ def from_name(cls, name: str): transformer_configs = { - "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), + "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), } class KVCache(nn.Module): @@ -122,13 +122,16 @@ class TransformerBlock(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.attention = Attention(config) - self.block_sparse_moe = MOEFeedForward(config) + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) + # self.block_sparse_moe = MOEFeedForward(config) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps) def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.block_sparse_moe(self.ffn_norm(h)) + moe_out = self.block_sparse_moe(self.ffn_norm(h)) + # import fbvscode; fbvscode.set_trace() + out = h + moe_out return out @@ -258,3 +261,146 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: x_out2 = x_out2.flatten(3) return x_out2.type_as(x) + + +# T tokens +# E experts +# D dim +# I intermediate dim +# A activated experts +# T'(e) tokens for expert e + +class MOEFeedForwardAOQuantizable(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) # x: [T, D] + padded=False + if x.shape[0] == 1: + padded=True + x = F.pad(x, (0,0,0,1)) + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts) + if padded: + return out[:-1] + return out + + +class ConditionalFeedForwardAOQuantizable(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) # E, D, I + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D + self.num_experts = config.num_experts + def forward( + self, x: Tensor, # T, D + expert_indices: Tensor, # T, A + expert_weights: Tensor, # T, A + num_activated_experts: int, + ) -> Tensor: + + if x.shape[0]==1: + # version 1 [should be fastest] + # out = torch.zeros_like(x) # T, D + # for activated_expert_idx in range(num_activated_experts): + # cur_expert = expert_indices[:, activated_expert_idx].squeeze() + # cur_weight = expert_weights[:, activated_expert_idx] # T' + + # w1=self.w1[cur_expert] # I, D + # w2=self.w2[cur_expert] # D, I + # w3=self.w3[cur_expert] # I, D + + # cur_out = F.linear( F.silu(F.linear(x, w1)) * F.linear(x, w3), w2) # T', D + # out += cur_out * cur_weight + + # base version + # w1_weights = self.w1[expert_indices] # [T, A, D, D] + # w3_weights = self.w3[expert_indices] # [T, A, D, D] + # w2_weights = self.w2[expert_indices] # [T, A, D, D] + # x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) + # x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) + + # expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) + # return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + # modified base version + # w1 = self.w1[expert_indices].reshape(-1, self.w1.shape[-1]) # [T, A, D, D] + # w3 = self.w3[expert_indices].reshape(-1, self.w3.shape[-1]) # [T, A, D, D] + # w2_weights = self.w2[expert_indices] + + # x1n = F.silu(F.linear(x, w1)) + # x3n = F.linear(x, w3) + # # x_finals = (x1n*x3n).split(self.w2.shape[-1], dim=1) + # x_finals = (x1n*x3n).reshape(1, num_activated_experts, -1) + # expert_outs2 = torch.einsum('tao, taio -> tai', x_finals, w2_weights) + # return (expert_outs2 * expert_weights.unsqueeze(-1)).sum(dim=1) + + # general version + # tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts) + # expert_list = [x for x in range(self.num_experts)] + + # augmented general version [why doesn't this work] + tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(F.pad(expert_indices, (0,0,0,1)), F.pad(expert_weights, (0,0,0,1)), self.num_experts) + expert_list = [x for x in range(self.num_experts)] + x = F.pad(x,(0,0,0,1)) + + out = torch.zeros_like(x) # T, D + for activated_expert_idx, expert in enumerate(expert_list): + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + tok_indices = tok_indices_per_expert[activated_expert_idx] + cur_x = x[tok_indices] # T', D + cur_weights = tok_weights_per_expert[activated_expert_idx] # T' + + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D + out[tok_indices] += cur_out * cur_weights + return out[:-1] + else: + # This works for both cases but isn't quantizable when only 1 token + tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts) + expert_list = [x for x in range(self.num_experts)] + # tok_indices_per_expert, tok_weights_per_expert: list([T'(e0) ,T'(e1) , ...]) + + out = torch.zeros_like(x) # T, D + for activated_expert_idx, expert in enumerate(expert_list): + w1=self.w1[expert] # I, D + w2=self.w2[expert] # D, I + w3=self.w3[expert] # I, D + + tok_indices = tok_indices_per_expert[activated_expert_idx] + cur_x = x[tok_indices] # T', D + cur_weights = tok_weights_per_expert[activated_expert_idx] # T' + + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D + out[tok_indices] += cur_out * cur_weights + + return out + + +def get_indices_and_weights_per_expert(expert_indices, expert_weights, num_experts): + num_tokens, experts_per_token = expert_indices.shape + # extract the tokens used by each expert by sorting expert_indices by expert, then indexing the sorted list based on how many tokens each expert gets + # expert_indices = [[0, 1] [1, 3], [0, 2]] + # tokens_per_expert = [2, 2, 1, 1] -> cum_tokens_per_expert = [0, 2, 4, 5, 6] (want 0 in front to make things easier) + # sorted_tokens_by_expert = [0, 2, 0, 1, 2, 1] + # tok_indices_per_expert = [|0, 2 | 0, 1 | 2, 1 |] -> [[0, 2] [0, 1] [2] [1]] + sorted_token_activation_by_expert = expert_indices.view(-1).argsort(dim=0) # + sorted_tokens_by_expert = (sorted_token_activation_by_expert/experts_per_token).floor().to(torch.int) + cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts+1, min=-1, max=num_experts).cumsum(0) + tok_indices_per_expert = [sorted_tokens_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i+1]] for i in range(num_experts)] + + # arrange weights in same way as tokens and then group weights by expert + sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert] + tok_weights_per_expert = [sorted_weights_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i+1]].view(-1, 1) for i in range(num_experts)] + + return tok_indices_per_expert, tok_weights_per_expert diff --git a/mixtral-moe/run.sh b/mixtral-moe/run.sh new file mode 100644 index 00000000..22f7ee05 --- /dev/null +++ b/mixtral-moe/run.sh @@ -0,0 +1,4 @@ +export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 + +# python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth +python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index f14ba6ca..67e451dd 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -93,7 +93,10 @@ def permute(w, n_head): if "layers" in key: abstract_key = re.sub(r'(\d+)', '{}', key) layer_num = re.search(r'\d+', key).group(0) - new_key = weight_map[abstract_key] + try: + new_key = weight_map[abstract_key] + except: + import fbvscode; fbvscode.set_trace() if new_key is None: continue new_key = new_key.format(layer_num) From 851a89a2d99cbebef61f8bcbc752318d4a16eceb Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 12 Mar 2025 22:31:06 -0700 Subject: [PATCH 2/7] swapping to base version Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/model.py | 52 ++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index aabcd435..f71eb851 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -309,17 +309,17 @@ def forward( if x.shape[0]==1: # version 1 [should be fastest] - # out = torch.zeros_like(x) # T, D - # for activated_expert_idx in range(num_activated_experts): - # cur_expert = expert_indices[:, activated_expert_idx].squeeze() - # cur_weight = expert_weights[:, activated_expert_idx] # T' + out = torch.zeros_like(x) # T, D + for activated_expert_idx in range(num_activated_experts): + cur_expert = expert_indices[:, activated_expert_idx].squeeze() + cur_weight = expert_weights[:, activated_expert_idx] # T' - # w1=self.w1[cur_expert] # I, D - # w2=self.w2[cur_expert] # D, I - # w3=self.w3[cur_expert] # I, D + w1=self.w1[cur_expert] # I, D + w2=self.w2[cur_expert] # D, I + w3=self.w3[cur_expert] # I, D - # cur_out = F.linear( F.silu(F.linear(x, w1)) * F.linear(x, w3), w2) # T', D - # out += cur_out * cur_weight + cur_out = F.linear( F.silu(F.linear(x, w1)) * F.linear(x, w3), w2) # T', D + out += cur_out * cur_weight # base version # w1_weights = self.w1[expert_indices] # [T, A, D, D] @@ -331,7 +331,7 @@ def forward( # expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) # return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) - # modified base version + # modified base version (slightly faster) # w1 = self.w1[expert_indices].reshape(-1, self.w1.shape[-1]) # [T, A, D, D] # w3 = self.w3[expert_indices].reshape(-1, self.w3.shape[-1]) # [T, A, D, D] # w2_weights = self.w2[expert_indices] @@ -348,23 +348,23 @@ def forward( # expert_list = [x for x in range(self.num_experts)] # augmented general version [why doesn't this work] - tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(F.pad(expert_indices, (0,0,0,1)), F.pad(expert_weights, (0,0,0,1)), self.num_experts) - expert_list = [x for x in range(self.num_experts)] - x = F.pad(x,(0,0,0,1)) - - out = torch.zeros_like(x) # T, D - for activated_expert_idx, expert in enumerate(expert_list): - w1=self.w1[expert] # I, D - w2=self.w2[expert] # D, I - w3=self.w3[expert] # I, D - - tok_indices = tok_indices_per_expert[activated_expert_idx] - cur_x = x[tok_indices] # T', D - cur_weights = tok_weights_per_expert[activated_expert_idx] # T' + # tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(F.pad(expert_indices, (0,0,0,1)), F.pad(expert_weights, (0,0,0,1)), self.num_experts) + # expert_list = [x for x in range(self.num_experts)] + # x = F.pad(x,(0,0,0,1)) - cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D - out[tok_indices] += cur_out * cur_weights - return out[:-1] + # out = torch.zeros_like(x) # T, D + # for activated_expert_idx, expert in enumerate(expert_list): + # w1=self.w1[expert] # I, D + # w2=self.w2[expert] # D, I + # w3=self.w3[expert] # I, D + + # tok_indices = tok_indices_per_expert[activated_expert_idx] + # cur_x = x[tok_indices] # T', D + # cur_weights = tok_weights_per_expert[activated_expert_idx] # T' + + # cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D + # out[tok_indices] += cur_out * cur_weights + # return out[:-1] else: # This works for both cases but isn't quantizable when only 1 token tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts) From d3bc46aba84267ad680f5de8605e6ffc2fa59dfb Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 31 Mar 2025 10:42:06 -0700 Subject: [PATCH 3/7] testing Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/generate.py | 129 +++++++++++++++++------ mixtral-moe/model.py | 171 +++++++++++++------------------ mixtral-moe/run.sh | 30 +++++- scripts/convert_hf_checkpoint.py | 5 +- 4 files changed, 200 insertions(+), 135 deletions(-) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 7e0d254f..61653c76 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -25,7 +25,7 @@ def device_sync(device): torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future - +torch._dynamo.config.capture_scalar_outputs = True # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -52,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non return probs def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) + probs = logits_to_probs(logits[:, -1], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs @@ -80,7 +80,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc new_tokens.append(next_token.clone()) callback(new_tokens[-1]) new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) + cur_token = next_token return new_tokens, new_probs @@ -93,6 +93,7 @@ def generate( model: Transformer, prompt: torch.Tensor, max_new_tokens: int, + batch_size: int, *, interactive: bool, callback = lambda x: x, @@ -101,31 +102,30 @@ def generate( """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. """ + device, dtype = prompt.device, prompt.dtype + + + T = prompt.size(-1) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # duplicate prompt for batch_size + prompt = prompt.repeat(batch_size, 1) # create an empty tensor of the expected final shape and fill in the current tokens - T = prompt.size(0) - T_new = T + max_new_tokens - if interactive: - max_seq_length = 350 - else: - max_seq_length = min(T_new, model.config.block_size) + seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) + seq[:, :T] = prompt - device, dtype = prompt.device, prompt.dtype with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length) - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty input_pos = torch.arange(0, T, device=device) - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token + next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs) + seq[:, T] = next_token.squeeze() input_pos = torch.tensor([T], device=device, dtype=torch.int) - - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) + generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) return seq @@ -167,6 +167,7 @@ def main( interactive: bool = False, num_samples: int = 5, max_new_tokens: int = 100, + batch_size: int = 1, top_k: int = 200, temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), @@ -178,7 +179,6 @@ def main( """Generates text samples based on a pre-trained Transformer model and tokenizer. """ # assert checkpoint_path.is_file(), checkpoint_path - tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), str(tokenizer_path) @@ -207,13 +207,77 @@ def main( torch.manual_seed(1234) model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + + + import torchao + from torchao.quantization import quantize_, Int8WeightOnlyConfig + # quantize_(model, Int8WeightOnlyConfig()) + + + from torchao.quantization.quant_primitives import MappingType + from torchao.dtypes import to_affine_quantized_intx + + def moe_filter(module, fqn): + return "MOEFeedForwardAOQuantizable" in str(type(module)) + + def cond_ffn_filter(module, fqn): + return "ConditionalFeedForwardAOQuantizable" in str(type(module)) + + def quant_convert_fn(module, config): + def quant_tensor(weight): + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = [1 for x in range(param.dim())] + block_size[-1] = param.shape[-1] + block_size = tuple(block_size) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + return new_weight + assert "ConditionalFeedForwardAOQuantizable" in str(type(module)) + assert hasattr(module, "w1") + assert hasattr(module, "w2") + assert hasattr(module, "w3") + + group_size = None if config.group_size is None else config.group_size + for weight_attr in ["w1", "w2", "w3"]: + param = getattr(module, weight_attr) + new_param = quant_tensor(param) + new_param = torch.nn.Parameter(new_param, requires_grad=False) + setattr(module, weight_attr, new_param) + del param + return module + + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + + # _replace_with_custom_fn_if_matches_filter( + # model, + # quant_convert_fn, + # cond_ffn_filter, + # extra_args=(Int8WeightOnlyConfig(),) + # ) + + + if compile: torch._inductor.config.assert_indirect_indexing = False + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + if batch_size > 1: # MoE code has graph break for multi token path so can't fullgraph compile + # decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead") + else: + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) - # Uncomment to squeeze more perf out of prefill if args.compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) @@ -260,6 +324,7 @@ def callback(x): model, encoded, max_new_tokens, + batch_size, interactive=interactive, callback=callback, temperature=temperature, @@ -277,16 +342,19 @@ def callback(x): t = time.perf_counter() - t0 if not interactive: - print(tokenizer.decode(y.tolist())) + print(tokenizer.decode(y[0].tolist())) else: print() - tokens_generated = y.size(0) - prompt_length + tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t aggregate_metrics['tokens_per_sec'].append(tokens_sec) print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") - print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + print(f"Average tokens/sec: {tokpersec:.2f}") + if batch_size > 1: + print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") @@ -294,11 +362,12 @@ def callback(x): import argparse parser = argparse.ArgumentParser(description='Your CLI description.') - # parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--prompt', type=str, default="H", help='Input prompt.') + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + # parser.add_argument('--num_samples', type=int, default=1, help='Number of samples.') + parser.add_argument('--num_samples', type=int, default=2, help='Number of samples.') parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') @@ -309,6 +378,6 @@ def callback(x): args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device ) diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index f71eb851..38c8ff13 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -106,8 +106,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx) - + for i, layer in enumerate(self.layers): + # if i>2: + # break x = layer(x, input_pos, freqs_cis, mask) x = self.norm(x) logits = self.output(x) @@ -130,8 +132,10 @@ def __init__(self, config: ModelArgs) -> None: def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) moe_out = self.block_sparse_moe(self.ffn_norm(h)) - # import fbvscode; fbvscode.set_trace() - out = h + moe_out + try: + out = h + moe_out + except: + import fbvscode; fbvscode.set_trace() return out @@ -278,24 +282,20 @@ def __init__(self, config) -> None: self.dim = config.dim self.num_activated_experts = config.num_activated_experts def forward(self, x: Tensor) -> Tensor: + batch_size = x.shape[0] x = x.view(-1, self.dim) # x: [T, D] - padded=False - if x.shape[0] == 1: - padded=True - x = F.pad(x, (0,0,0,1)) scores = self.gate(x) # [T, E] expert_weights = F.softmax(scores, dim=-1) expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] - expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts) - if padded: - return out[:-1] - return out + return out.reshape(batch_size, -1, self.dim) class ConditionalFeedForwardAOQuantizable(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) # E, D, I self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D @@ -306,101 +306,72 @@ def forward( expert_weights: Tensor, # T, A num_activated_experts: int, ) -> Tensor: - - if x.shape[0]==1: - # version 1 [should be fastest] - out = torch.zeros_like(x) # T, D - for activated_expert_idx in range(num_activated_experts): - cur_expert = expert_indices[:, activated_expert_idx].squeeze() - cur_weight = expert_weights[:, activated_expert_idx] # T' - - w1=self.w1[cur_expert] # I, D - w2=self.w2[cur_expert] # D, I - w3=self.w3[cur_expert] # I, D - - cur_out = F.linear( F.silu(F.linear(x, w1)) * F.linear(x, w3), w2) # T', D - out += cur_out * cur_weight - - # base version - # w1_weights = self.w1[expert_indices] # [T, A, D, D] - # w3_weights = self.w3[expert_indices] # [T, A, D, D] - # w2_weights = self.w2[expert_indices] # [T, A, D, D] - # x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) - # x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) + num_tokens, dim = x.shape + num_token_activations = num_tokens * num_activated_experts + + if x.shape[0]==1: #only 1 token (can be done without graph breaks when compiled) + outs = [] + expert_indices=expert_indices.squeeze() + # collect used experts + w1 = self.w1[expert_indices] + w2 = self.w2[expert_indices] + w3 = self.w3[expert_indices] + + # run token through each expert + for index in range(num_activated_experts): + cur_out = F.linear( F.silu(F.linear(x, w1[index])) * F.linear(x, w3[index]), w2[index]) + outs.append(cur_out) + + # combine outputs + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) + return final_out + else: + expert_list = [x for x in range(self.num_experts)] - # expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) - # return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + # shuffle tokens into groups for each expert + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] - # modified base version (slightly faster) - # w1 = self.w1[expert_indices].reshape(-1, self.w1.shape[-1]) # [T, A, D, D] - # w3 = self.w3[expert_indices].reshape(-1, self.w3.shape[-1]) # [T, A, D, D] - # w2_weights = self.w2[expert_indices] + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0) # [E+1] - # x1n = F.silu(F.linear(x, w1)) - # x3n = F.linear(x, w3) - # # x_finals = (x1n*x3n).split(self.w2.shape[-1], dim=1) - # x_finals = (x1n*x3n).reshape(1, num_activated_experts, -1) - # expert_outs2 = torch.einsum('tao, taio -> tai', x_finals, w2_weights) - # return (expert_outs2 * expert_weights.unsqueeze(-1)).sum(dim=1) - - # general version - # tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts) - # expert_list = [x for x in range(self.num_experts)] - - # augmented general version [why doesn't this work] - # tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(F.pad(expert_indices, (0,0,0,1)), F.pad(expert_weights, (0,0,0,1)), self.num_experts) - # expert_list = [x for x in range(self.num_experts)] - # x = F.pad(x,(0,0,0,1)) - - # out = torch.zeros_like(x) # T, D - # for activated_expert_idx, expert in enumerate(expert_list): - # w1=self.w1[expert] # I, D - # w2=self.w2[expert] # D, I - # w3=self.w3[expert] # I, D - - # tok_indices = tok_indices_per_expert[activated_expert_idx] - # cur_x = x[tok_indices] # T', D - # cur_weights = tok_weights_per_expert[activated_expert_idx] # T' - - # cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D - # out[tok_indices] += cur_out * cur_weights - # return out[:-1] - else: - # This works for both cases but isn't quantizable when only 1 token - tok_indices_per_expert, tok_weights_per_expert = get_indices_and_weights_per_expert(expert_indices, expert_weights, self.num_experts) - expert_list = [x for x in range(self.num_experts)] - # tok_indices_per_expert, tok_weights_per_expert: list([T'(e0) ,T'(e1) , ...]) + # needed to pull this into a function to apply this decorator since compile doesn't like it + # @torch._dynamo.disable() + # def group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list): + # token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... + # tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + # return tokens_grouped_by_expert + # tokens_grouped_by_expert = group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list) + + @torch._dynamo.disable() + def group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list): + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... + return token_indices_per_expert + + # token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... + token_indices_per_expert = group_tokens_by_expert(x, ordered_token_indices, cum_tokens_per_expert, expert_list) + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] + + + # calculate outputs for each expert + outs = [] + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): + # if x.shape[0]<24: + # import fbvscode; fbvscode.set_trace() - out = torch.zeros_like(x) # T, D - for activated_expert_idx, expert in enumerate(expert_list): w1=self.w1[expert] # I, D w2=self.w2[expert] # D, I w3=self.w3[expert] # I, D - tok_indices = tok_indices_per_expert[activated_expert_idx] - cur_x = x[tok_indices] # T', D - cur_weights = tok_weights_per_expert[activated_expert_idx] # T' - - cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # T', D - out[tok_indices] += cur_out * cur_weights - - return out + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # [T'(e), D] + outs.append(cur_out) - -def get_indices_and_weights_per_expert(expert_indices, expert_weights, num_experts): - num_tokens, experts_per_token = expert_indices.shape - # extract the tokens used by each expert by sorting expert_indices by expert, then indexing the sorted list based on how many tokens each expert gets - # expert_indices = [[0, 1] [1, 3], [0, 2]] - # tokens_per_expert = [2, 2, 1, 1] -> cum_tokens_per_expert = [0, 2, 4, 5, 6] (want 0 in front to make things easier) - # sorted_tokens_by_expert = [0, 2, 0, 1, 2, 1] - # tok_indices_per_expert = [|0, 2 | 0, 1 | 2, 1 |] -> [[0, 2] [0, 1] [2] [1]] - sorted_token_activation_by_expert = expert_indices.view(-1).argsort(dim=0) # - sorted_tokens_by_expert = (sorted_token_activation_by_expert/experts_per_token).floor().to(torch.int) - cum_tokens_per_expert = torch.histc(expert_indices, bins=num_experts+1, min=-1, max=num_experts).cumsum(0) - tok_indices_per_expert = [sorted_tokens_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i+1]] for i in range(num_experts)] - - # arrange weights in same way as tokens and then group weights by expert - sorted_weights_by_expert = expert_weights.view(-1)[sorted_token_activation_by_expert] - tok_weights_per_expert = [sorted_weights_by_expert[cum_tokens_per_expert[i]:cum_tokens_per_expert[i+1]].view(-1, 1) for i in range(num_experts)] - - return tok_indices_per_expert, tok_weights_per_expert + # weigh outputs + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] + + # sum weighted token-activation outputs together for each token + final_out = torch.zeros_like(x) # [T, D] + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations,dim), src=weighted_ordered_outs) + return final_out diff --git a/mixtral-moe/run.sh b/mixtral-moe/run.sh index 22f7ee05..05f4cf27 100644 --- a/mixtral-moe/run.sh +++ b/mixtral-moe/run.sh @@ -1,4 +1,32 @@ export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 # python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth -python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth +# echo "1" +# echo "python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 +# " +# python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 +# echo "2" +# echo "python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 +# " +# python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 +# echo "3" +# echo "python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 +# " +# python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 + +python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile + + +# python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 --compile --profile "no_q_profile" + +# quant reduced layers +# Time for inference 2: 2.25 sec total, 88.82 tokens/sec +# Bandwidth achieved: 8296.64 GB/s +# Average tokens/sec: 163.61 +# Memory used: 94.12 GB + +# no quant +# Time for inference 2: 0.52 sec total, 385.94 tokens/sec +# Bandwidth achieved: 36049.12 GB/s +# Average tokens/sec: 385.58 +# Memory used: 93.57 GB diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 67e451dd..f14ba6ca 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -93,10 +93,7 @@ def permute(w, n_head): if "layers" in key: abstract_key = re.sub(r'(\d+)', '{}', key) layer_num = re.search(r'\d+', key).group(0) - try: - new_key = weight_map[abstract_key] - except: - import fbvscode; fbvscode.set_trace() + new_key = weight_map[abstract_key] if new_key is None: continue new_key = new_key.format(layer_num) From c94b0055ce6bb9a7637ce322f4b51534aa0cf68b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 31 Mar 2025 12:17:46 -0700 Subject: [PATCH 4/7] testing Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/generate.py | 2 +- mixtral-moe/model.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 61653c76..5c2bd086 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -211,7 +211,7 @@ def main( import torchao from torchao.quantization import quantize_, Int8WeightOnlyConfig - # quantize_(model, Int8WeightOnlyConfig()) + quantize_(model, Int8WeightOnlyConfig()) from torchao.quantization.quant_primitives import MappingType diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index 38c8ff13..a4d5cdac 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -132,10 +132,7 @@ def __init__(self, config: ModelArgs) -> None: def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) moe_out = self.block_sparse_moe(self.ffn_norm(h)) - try: - out = h + moe_out - except: - import fbvscode; fbvscode.set_trace() + out = h + moe_out return out From 08edaed69c5b20b6f44dc070a31aa587665a117b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 31 Mar 2025 13:16:13 -0700 Subject: [PATCH 5/7] testing Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/generate.py | 2 +- mixtral-moe/model.py | 4 ++-- mixtral-moe/run.sh | 8 +------- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 5c2bd086..61653c76 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -211,7 +211,7 @@ def main( import torchao from torchao.quantization import quantize_, Int8WeightOnlyConfig - quantize_(model, Int8WeightOnlyConfig()) + # quantize_(model, Int8WeightOnlyConfig()) from torchao.quantization.quant_primitives import MappingType diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index a4d5cdac..e2bcc5cf 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -108,8 +108,8 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: x = self.tok_embeddings(idx) for i, layer in enumerate(self.layers): - # if i>2: - # break + if i>2: + break x = layer(x, input_pos, freqs_cis, mask) x = self.norm(x) logits = self.output(x) diff --git a/mixtral-moe/run.sh b/mixtral-moe/run.sh index 05f4cf27..9692f8ad 100644 --- a/mixtral-moe/run.sh +++ b/mixtral-moe/run.sh @@ -14,7 +14,7 @@ export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 # " # python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 -python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile +python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile --profile no_q_new_model # python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 --compile --profile "no_q_profile" @@ -24,9 +24,3 @@ python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --comp # Bandwidth achieved: 8296.64 GB/s # Average tokens/sec: 163.61 # Memory used: 94.12 GB - -# no quant -# Time for inference 2: 0.52 sec total, 385.94 tokens/sec -# Bandwidth achieved: 36049.12 GB/s -# Average tokens/sec: 385.58 -# Memory used: 93.57 GB From 27fb799a5fac3ca005e77926e2346bc3c836480c Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 31 Mar 2025 14:33:56 -0700 Subject: [PATCH 6/7] test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/generate.py | 9 ++++++++- mixtral-moe/model.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 61653c76..1c0a655f 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -211,7 +211,14 @@ def main( import torchao from torchao.quantization import quantize_, Int8WeightOnlyConfig - # quantize_(model, Int8WeightOnlyConfig()) + + + def filter(model, fqn): + return isinstance(model, torch.nn.Linear) and "gate" not in fqn + + quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter) + + quantize_(model, Int8WeightOnlyConfig()) from torchao.quantization.quant_primitives import MappingType diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index e2bcc5cf..a4d5cdac 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -108,8 +108,8 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: x = self.tok_embeddings(idx) for i, layer in enumerate(self.layers): - if i>2: - break + # if i>2: + # break x = layer(x, input_pos, freqs_cis, mask) x = self.norm(x) logits = self.output(x) From 8a64319ec1372c23556016ebb1c1fcc60aa21e00 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Mon, 31 Mar 2025 14:41:12 -0700 Subject: [PATCH 7/7] test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- mixtral-moe/generate.py | 5 +---- mixtral-moe/model.py | 2 +- mixtral-moe/run.sh | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 1c0a655f..43e7c585 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -217,8 +217,6 @@ def filter(model, fqn): return isinstance(model, torch.nn.Linear) and "gate" not in fqn quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter) - - quantize_(model, Int8WeightOnlyConfig()) from torchao.quantization.quant_primitives import MappingType @@ -275,8 +273,7 @@ def quant_tensor(weight): if compile: torch._inductor.config.assert_indirect_indexing = False - torch._dynamo.config.capture_dynamic_output_shape_ops = True - torch._dynamo.config.capture_scalar_outputs = True + # torch._dynamo.config.capture_dynamic_output_shape_ops = True global decode_one_token, prefill if batch_size > 1: # MoE code has graph break for multi token path so can't fullgraph compile diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index a4d5cdac..e930f37c 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -106,7 +106,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx) - + for i, layer in enumerate(self.layers): # if i>2: # break diff --git a/mixtral-moe/run.sh b/mixtral-moe/run.sh index 9692f8ad..c92b11f1 100644 --- a/mixtral-moe/run.sh +++ b/mixtral-moe/run.sh @@ -14,7 +14,7 @@ export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1 # " # python generate.py --compile --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 -python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile --profile no_q_new_model +python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --compile # python generate.py --checkpoint_path ../checkpoints/$MODEL_REPO/model.pth --batch_size 4 --compile --profile "no_q_profile"