2 changes: 2 additions & 0 deletions caikit_nlp/config/config.yml
@@ -52,6 +52,8 @@ embedding:
autocast: false
# Use graph mode with IPEX CPU
graphmode: false
# For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU.
# Otherwise, the default does automatic checks for cuda GPU (else cpu).
device: ""

runtime:
76 changes: 61 additions & 15 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -14,6 +14,7 @@

# Standard
from collections.abc import Sized
from contextlib import nullcontext
from enum import Enum, auto
from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
import importlib
@@ -84,6 +85,7 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument

AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast"))
IPEX = env_val_to_bool(val=embedding_cfg.get("ipex"))
GRAPH_MODE = env_val_to_bool(val=embedding_cfg.get("graphmode"))
PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
@@ -801,6 +803,14 @@ def sum_token_count(


class SentenceTransformerWithTruncate(SentenceTransformer):
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if GRAPH_MODE:
        # Compile and cache the TorchScript graph once, right after base init
        self.compiled_model = self._apply_graph_mode()

def _truncate_input_tokens(
self,
truncate_input_tokens: int,
@@ -904,6 +914,44 @@ def _truncate_input_tokens(

return TruncatedTokensTuple(tokenized, input_token_count)

def _apply_graph_mode(self) -> torch.jit.ScriptModule:
"""
Compiles the model into a TorchScript graph using predefined fixed-size randomized
input tensors.The tensors simulate typical input structures without relying
on actual input feature data.

:return: A TorchScript graph that is optimized for inference.
"""
self.eval()

max_seq_length = self.max_seq_length
vocab_size = self.tokenizer.vocab_size

# Generate random input_ids within the vocabulary range and a full attention mask
input_ids = torch.randint(low=0, high=vocab_size, size=(1, max_seq_length))
attention_mask = torch.ones(1, max_seq_length).int()

# Context manager for automatic mixed precision, if applicable
context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext()

with torch.no_grad(), context_manager:
# Trace the model with the synthetic input to create a TorchScript graph
compiled_graph = torch.jit.trace(
    self,
    ({"input_ids": input_ids, "attention_mask": attention_mask},),
    strict=False,
)

# Freeze the compiled graph to optimize it for runtime performance
compiled_graph = torch.jit.freeze(compiled_graph)

return compiled_graph
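
The trace-and-freeze pattern above generalizes beyond this module. Below is a minimal, self-contained sketch of the same technique, where `TinyEncoder` is a hypothetical stand-in for the SentenceTransformer model: like it, the toy module consumes and produces feature dicts. Shapes and vocabulary size are illustrative only.

```python
# Minimal sketch of the trace-and-freeze pattern in _apply_graph_mode().
# TinyEncoder is a hypothetical stand-in for the real SentenceTransformer.
import torch
from torch import nn


class TinyEncoder(nn.Module):
    def __init__(self, vocab_size: int = 100, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, features):
        # Mean-pool token embeddings, weighting by the attention mask
        token_embeddings = self.embed(features["input_ids"])
        mask = features["attention_mask"].unsqueeze(-1).float()
        pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)
        return {"sentence_embedding": pooled}


model = TinyEncoder().eval()
example = {
    "input_ids": torch.randint(low=0, high=100, size=(1, 8)),
    "attention_mask": torch.ones(1, 8, dtype=torch.int64),
}

with torch.no_grad():
    # strict=False permits the dict-structured input and output
    graph = torch.jit.trace(model, (example,), strict=False)
    # Freezing inlines weights and folds constants for faster inference
    graph = torch.jit.freeze(graph)

print(graph(example)["sentence_embedding"].shape)  # torch.Size([1, 16])
```

Freezing is what makes the fixed-size trace pay off at runtime: parameters become constants in the graph, enabling constant folding and related optimizations.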

def encode(
self,
sentences: Union[str, List[str]],
Expand Down Expand Up @@ -954,7 +1002,7 @@ def encode(
output_value,
normalize_embeddings,
)

# TorchScript requires eval mode
self.eval()

if convert_to_tensor:
Expand Down Expand Up @@ -999,20 +1047,18 @@ def encode(

features = batch_to_device(features, device)

if AUTOCAST:
with torch.no_grad(), torch.cpu.amp.autocast():
out_features = self.forward(features)
embeddings = out_features["sentence_embedding"]
if convert_to_numpy:
embeddings = embeddings.detach().cpu()
all_embeddings.extend(embeddings)
else:
with torch.no_grad():
out_features = self.forward(features)
embeddings = out_features["sentence_embedding"]
if convert_to_numpy:
embeddings = embeddings.detach().cpu()
all_embeddings.extend(embeddings)
# Determine which model to use based on GRAPH_MODE
model_to_use = self.compiled_model if GRAPH_MODE else self.forward

# Execution context based on AUTOCAST
context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext()

with torch.no_grad(), context_manager:
out_features = model_to_use(features)
embeddings = out_features["sentence_embedding"]
if convert_to_numpy:
embeddings = embeddings.detach().cpu()
all_embeddings.extend(embeddings)

# Restore original order
all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
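
The refactored `encode()` loop boils down to a small dispatch: pick the compiled graph or the eager forward, then run under `torch.no_grad()` with autocast entered only when enabled. A sketch of that pattern in isolation, where the `graph_mode` and `autocast` parameters are illustrative placeholders standing in for the module-level `GRAPH_MODE` and `AUTOCAST` constants:

```python
# Sketch of the dispatch pattern from encode(); the boolean flags here are
# illustrative parameters, not the module-level GRAPH_MODE and AUTOCAST.
from contextlib import nullcontext

import torch


def embed_batch(model, features, graph_mode: bool, autocast: bool):
    # Prefer the frozen TorchScript graph when graph mode is enabled
    model_to_use = model.compiled_model if graph_mode else model.forward
    # Enter CPU mixed precision only when requested
    context_manager = torch.cpu.amp.autocast() if autocast else nullcontext()
    with torch.no_grad(), context_manager:
        out_features = model_to_use(features)
    return out_features["sentence_embedding"]
```

Collapsing the old if/else with `nullcontext()` removes the duplicated `no_grad`/forward/convert logic while keeping the autocast-off path behaviorally identical.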