From 797541e36d637c6dbda2aa9cc6eb287340a89885 Mon Sep 17 00:00:00 2001 From: devpramod Date: Thu, 16 May 2024 20:46:13 +0000 Subject: [PATCH 1/4] add graph mode Signed-off-by: devpramod --- caikit_nlp/config/config.yml | 6 +- .../modules/text_embedding/embedding.py | 67 +++++++++++++++---- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index 6f440a22..4d8bb4f3 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -47,11 +47,13 @@ embedding: # Attempt to optimize with PyTorch compile() pt2_compile: false # Use IPEX optimize. Works best when used with autocast (bfloat16) below. - ipex: false + ipex: true # Use autocast in encode with its default dtype (bfloat16) - autocast: false + autocast: true # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU. # Otherwise, the default does automatic checks for cuda GPU (else cpu). + graphmode: true + # Use graph mode with IPEX CPU device: "" runtime: diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index c1eb7173..dd2cc255 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -17,6 +17,7 @@ from enum import Enum, auto from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union import importlib +from contextlib import nullcontext import os import time @@ -84,6 +85,7 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast")) IPEX = env_val_to_bool(val=embedding_cfg.get("ipex")) +GRAPH_MODE = env_val_to_bool(val=embedding_cfg.get("graphmode")) PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile")) RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0) BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0) @@ -903,6 +905,40 @@ def _truncate_input_tokens( ) return TruncatedTokensTuple(tokenized, input_token_count) + + def _apply_graph_mode(self) -> torch.jit.ScriptModule: + """ + Compiles the model into a TorchScript graph using predefined fixed-size randomized input tensors. + The tensors simulate typical input structures without relying on actual input feature data. + + :return: A TorchScript graph that is optimized for inference. 
+ """ + max_seq_length = self.max_seq_length + vocab_size = self.tokenizer.vocab_size + + # Generate random input_ids within the vocabulary range and a full attention mask + input_ids = torch.randint(low=0, high=vocab_size, size=(1, max_seq_length)) + attention_mask = torch.ones(1, max_seq_length).int() + + # Context manager for automatic mixed precision, if applicable + context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext() + + with torch.no_grad(), context_manager: + # Trace the model with the synthetic input to create a TorchScript graph + compiled_graph = torch.jit.trace( + self, + ({ + "input_ids": input_ids, + "attention_mask": attention_mask, + }), + strict=False, + ) + + # Freeze the compiled graph to optimize it for runtime performance + compiled_graph = torch.jit.freeze(compiled_graph) + + return compiled_graph + def encode( self, @@ -999,20 +1035,23 @@ def encode( features = batch_to_device(features, device) - if AUTOCAST: - with torch.no_grad(), torch.cpu.amp.autocast(): - out_features = self.forward(features) - embeddings = out_features["sentence_embedding"] - if convert_to_numpy: - embeddings = embeddings.detach().cpu() - all_embeddings.extend(embeddings) - else: - with torch.no_grad(): - out_features = self.forward(features) - embeddings = out_features["sentence_embedding"] - if convert_to_numpy: - embeddings = embeddings.detach().cpu() - all_embeddings.extend(embeddings) + # load model that will be JIT compiled as a graph + if GRAPH_MODE: + # compiled_model = self._apply_graph_mode(features) + compiled_model = self._apply_graph_mode() + + # Determine which model to use based on GRAPH_MODE + model_to_use = compiled_model if GRAPH_MODE else self.forward + + # Execution context based on AUTOCAST + context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext() + + with torch.no_grad(), context_manager: + out_features = model_to_use(features) + embeddings = out_features["sentence_embedding"] + if convert_to_numpy: + embeddings = embeddings.detach().cpu() + all_embeddings.extend(embeddings) # Restore original order all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] From 5a67e19d4e63c0e72ec923651d939674b9126543 Mon Sep 17 00:00:00 2001 From: devpramod Date: Fri, 17 May 2024 15:23:19 +0000 Subject: [PATCH 2/4] graph mode compiled in constructor Signed-off-by: devpramod --- caikit_nlp/config/config.yml | 6 +++--- .../modules/text_embedding/embedding.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index 4d8bb4f3..09d6af89 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -47,12 +47,12 @@ embedding: # Attempt to optimize with PyTorch compile() pt2_compile: false # Use IPEX optimize. Works best when used with autocast (bfloat16) below. - ipex: true + ipex: false # Use autocast in encode with its default dtype (bfloat16) - autocast: true + autocast: false # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU. # Otherwise, the default does automatic checks for cuda GPU (else cpu). 
- graphmode: true + graphmode: false # Use graph mode with IPEX CPU device: "" diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index dd2cc255..aa771f97 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -803,6 +803,12 @@ def sum_token_count( class SentenceTransformerWithTruncate(SentenceTransformer): + def __init__(self, *args, **kwargs): + super(SentenceTransformerWithTruncate, self).__init__(*args, **kwargs) + if GRAPH_MODE: + # Initialize the compiled model right after the base class initialization + self.compiled_model = self._apply_graph_mode() # Compile and store the graph model + def _truncate_input_tokens( self, truncate_input_tokens: int, @@ -913,6 +919,8 @@ def _apply_graph_mode(self) -> torch.jit.ScriptModule: :return: A TorchScript graph that is optimized for inference. """ + self.eval() + max_seq_length = self.max_seq_length vocab_size = self.tokenizer.vocab_size @@ -936,7 +944,7 @@ def _apply_graph_mode(self) -> torch.jit.ScriptModule: # Freeze the compiled graph to optimize it for runtime performance compiled_graph = torch.jit.freeze(compiled_graph) - + return compiled_graph @@ -990,7 +998,7 @@ def encode( output_value, normalize_embeddings, ) - + # torchscript requires eval mode self.eval() if convert_to_tensor: @@ -1035,13 +1043,8 @@ def encode( features = batch_to_device(features, device) - # load model that will be JIT compiled as a graph - if GRAPH_MODE: - # compiled_model = self._apply_graph_mode(features) - compiled_model = self._apply_graph_mode() - # Determine which model to use based on GRAPH_MODE - model_to_use = compiled_model if GRAPH_MODE else self.forward + model_to_use = self.compiled_model if GRAPH_MODE else self.forward # Execution context based on AUTOCAST context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext() From c9c5245605f694cffaebb695233418ffa39fda55 Mon Sep 17 00:00:00 2001 From: devpramod Date: Mon, 3 Jun 2024 18:25:41 +0000 Subject: [PATCH 3/4] formatting and linting Signed-off-by: devpramod --- .../modules/text_embedding/embedding.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index aa771f97..13e336fc 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -14,10 +14,10 @@ # Standard from collections.abc import Sized +from contextlib import nullcontext from enum import Enum, auto from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union import importlib -from contextlib import nullcontext import os import time @@ -804,11 +804,13 @@ def sum_token_count( class SentenceTransformerWithTruncate(SentenceTransformer): def __init__(self, *args, **kwargs): - super(SentenceTransformerWithTruncate, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if GRAPH_MODE: # Initialize the compiled model right after the base class initialization - self.compiled_model = self._apply_graph_mode() # Compile and store the graph model - + self.compiled_model = ( + self._apply_graph_mode() + ) # Compile and store the graph model + def _truncate_input_tokens( self, truncate_input_tokens: int, @@ -911,11 +913,12 @@ def _truncate_input_tokens( ) return TruncatedTokensTuple(tokenized, input_token_count) - + def _apply_graph_mode(self) -> torch.jit.ScriptModule: """ - Compiles the model into a 
TorchScript graph using predefined fixed-size randomized input tensors. - The tensors simulate typical input structures without relying on actual input feature data. + Compiles the model into a TorchScript graph using predefined fixed-size randomized + input tensors.The tensors simulate typical input structures without relying + on actual input feature data. :return: A TorchScript graph that is optimized for inference. """ @@ -923,7 +926,7 @@ def _apply_graph_mode(self) -> torch.jit.ScriptModule: max_seq_length = self.max_seq_length vocab_size = self.tokenizer.vocab_size - + # Generate random input_ids within the vocabulary range and a full attention mask input_ids = torch.randint(low=0, high=vocab_size, size=(1, max_seq_length)) attention_mask = torch.ones(1, max_seq_length).int() @@ -935,18 +938,19 @@ def _apply_graph_mode(self) -> torch.jit.ScriptModule: # Trace the model with the synthetic input to create a TorchScript graph compiled_graph = torch.jit.trace( self, - ({ - "input_ids": input_ids, - "attention_mask": attention_mask, - }), + ( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + ), strict=False, ) # Freeze the compiled graph to optimize it for runtime performance compiled_graph = torch.jit.freeze(compiled_graph) - - return compiled_graph + return compiled_graph def encode( self, From 5f542602391bd9b5023b7705a4313921f988bb83 Mon Sep 17 00:00:00 2001 From: devpramod Date: Mon, 3 Jun 2024 20:53:43 +0000 Subject: [PATCH 4/4] minor formatting fix Signed-off-by: devpramod --- caikit_nlp/modules/text_embedding/embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 13e336fc..0a85b72d 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -916,8 +916,8 @@ def _truncate_input_tokens( def _apply_graph_mode(self) -> torch.jit.ScriptModule: """ - Compiles the model into a TorchScript graph using predefined fixed-size randomized - input tensors.The tensors simulate typical input structures without relying + Compiles the model into a TorchScript graph using predefined fixed-size randomized + input tensors.The tensors simulate typical input structures without relying on actual input feature data. :return: A TorchScript graph that is optimized for inference.
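
Note (appended after the patch series): the following is a minimal, hedged sketch of the trace-and-freeze pattern that _apply_graph_mode() relies on, shown against a plain Hugging Face encoder rather than the caikit_nlp SentenceTransformer subclass. The model id, sequence length, and the use of transformers.AutoModel are illustrative assumptions, not what caikit_nlp actually loads.

# Hedged sketch, not part of the patches above. It mirrors the trace-and-freeze
# steps in _apply_graph_mode(): build synthetic fixed-shape inputs, trace once,
# then freeze the resulting TorchScript graph for inference.
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"  # assumed example checkpoint
MAX_SEQ_LENGTH = 128                                 # assumed fixed trace length

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID, torchscript=True)
model.eval()  # tracing and torch.jit.freeze both require eval mode

# Synthetic inputs: random token ids within the vocabulary plus a full attention
# mask, mirroring the randomized fixed-size tensors the patch builds for tracing.
input_ids = torch.randint(low=0, high=tokenizer.vocab_size, size=(1, MAX_SEQ_LENGTH))
attention_mask = torch.ones(1, MAX_SEQ_LENGTH, dtype=torch.int64)

with torch.no_grad():
    traced = torch.jit.trace(model, (input_ids, attention_mask), strict=False)
    frozen = torch.jit.freeze(traced)  # inline weights/attributes into the graph

# Reuse the frozen graph; with torchscript=True the model returns a tuple of tensors.
with torch.no_grad():
    outputs = frozen(input_ids, attention_mask)
print(outputs[0].shape)  # last hidden state: (1, MAX_SEQ_LENGTH, hidden_size)

Freezing inlines the module's parameters and attributes as constants, which is what lets the TorchScript executor (and, when installed, IPEX graph optimizations) rewrite the graph for faster CPU inference.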
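
A second sketch, again with assumed flag values and names, shows the dispatch that the patched encode() performs: pick the frozen graph when graph mode is on, otherwise the eager forward pass, and wrap either call in CPU autocast or a no-op context.

from contextlib import nullcontext
import torch

GRAPH_MODE = True  # assumed stand-ins for the config-driven constants in the patch
AUTOCAST = True

def embed_batch(model, compiled_model, features):
    """Run one tokenized batch, choosing the traced graph or eager forward."""
    model_to_use = compiled_model if GRAPH_MODE else model.forward
    # bfloat16 autocast on CPU when enabled, otherwise a do-nothing context.
    context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext()
    with torch.no_grad(), context_manager:
        return model_to_use(features)

TorchScript's profiling executor typically specializes a traced graph over its first few invocations, so warming the compiled model up with a couple of representative batches before measuring latency is common practice.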