Skip to content

Commit 5fb24ba

Browse files
author
Cloud User
committed
feat: Improve encoder and subset selection robustness
- Add robust instruction handling in Arctic encoder
- Enhance error checking for instruction and configuration
- Refactor device handling in subset selection utilities
- Remove hardcoded device selection
- Simplify logging and error management

Signed-off-by: Cloud User <ec2-user@ip-172-31-44-225.ec2.internal>
1 parent 7f6168c commit 5fb24ba

File tree

3 files changed

+25
-25
lines changed

3 files changed

+25
-25
lines changed

src/instructlab/sdg/encoders/arctic_encoder.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from typing import Dict, List, Optional, TypedDict, Union
66
import os
7-
7+
import logging
88
# Third Party
99
from tqdm import tqdm
1010
from transformers import AutoModel, AutoTokenizer
@@ -13,13 +13,14 @@
1313
import torch.distributed as dist
1414
import torch.nn.functional as F
1515

16+
logger = logging.getLogger(__name__)
1617
os.environ["TOKENIZERS_PARALLELISM"] = "false"
1718

1819

1920
def safe_print(rank, msg):
2021
"""Only print from rank 0."""
2122
if rank == 0:
22-
print(msg, flush=True)
23+
logger.info(msg)
2324

2425

2526
# Define model configuration
@@ -97,7 +98,7 @@ def _initialize_model(self) -> None:
9798
self.model = self.model.to(self.cfg.device)
9899

99100
if self.cfg.num_gpus > 1:
100-
print(f"Using {self.cfg.num_gpus} GPUs")
101+
logger.info(f"Using {self.cfg.num_gpus} GPUs")
101102
self.model = torch.nn.DataParallel(self.model)
102103

103104
self.model.eval()
@@ -109,15 +110,27 @@ def _prepare_inputs(
109110
if isinstance(texts, str):
110111
texts = [texts]
111112

113+
#Ensure we always have an instruction
114+
if not instruction and not self.cfg.use_default_instruction:
115+
raise ValueError(
116+
"An instruction must be provided when use_default_instruction is False. "
117+
"Either provide an instruction or set use_default_instruction to True."
118+
)
119+
112120
if (
113121
not instruction
114122
and self.cfg.use_default_instruction
115123
and self.cfg.model_config["default_instruction"]
116124
):
117125
instruction = str(self.cfg.model_config["default_instruction"])
118126

119-
if instruction:
120-
texts = [f"{instruction}: {text}" for text in texts]
127+
if not instruction: #catch if default_instruction is empty
128+
raise ValueError(
129+
"No instruction available. Either provide an instruction or ensure "
130+
"the model config has a valid default_instruction."
131+
)
132+
133+
texts = [f"{instruction}: {text}" for text in texts]
121134
return texts
122135

123136
@torch.no_grad()

src/instructlab/sdg/subset_selection.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,6 @@ class ProcessingConfig:
116116

117117
def __post_init__(self):
118118
"""Validate configuration after initialization."""
119-
if not self.input_files:
120-
raise ValueError("input_files cannot be empty")
121-
122119
if not isinstance(self.subset_sizes, list):
123120
raise ValueError("subset_sizes must be a list")
124121

@@ -903,19 +900,8 @@ def subset_datasets(
903900
)
904901

905902
try:
906-
# logger.info(f"Processing configuration: {config}")
907-
908-
# # Initialize data processor based on encoder type
909-
# os.makedirs(config.basic.output_dir, exist_ok=True)
910-
911-
# if config.encoder.encoder_type == "arctic":
912-
# processor = DataProcessor(config, ArcticEmbedEncoder)
913-
# else:
914-
# supported_encoders = get_supported_encoders()
915-
# raise ValueError(
916-
# f"Unsupported encoder type: {config.encoder.encoder_type}."
917-
# f"Supported types are: {', '.join(supported_encoders)}"
918-
# )
903+
logger.info(f"Processing configuration: {config}")
904+
919905
processor = DataProcessor(
920906
config, get_encoder_class(config.encoder.encoder_type)
921907
)

src/instructlab/sdg/utils/subset_selection_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# Standard
2-
from typing import Optional
2+
from typing import Optional, Union
33
import logging
44

55
# Third Party
66
from torch import Tensor
77
from torch.nn import functional as F
88
import torch
99

10-
__DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11-
1210
# Configure logging
1311
logging.basicConfig(
1412
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -30,13 +28,16 @@ def compute_pairwise_dense(
3028
tensor2: Optional[Tensor] = None,
3129
batch_size: int = 10000,
3230
metric: str = "cosine",
33-
device: str = __DEVICE,
31+
device: Optional[Union[str, torch.device]] = None,
3432
scaling: Optional[str] = None,
3533
kw: float = 0.1,
3634
) -> Tensor:
3735
"""Compute pairwise metric in batches between two sets of vectors."""
3836
assert batch_size > 0, "Batch size must be positive."
3937

38+
if not device:
39+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40+
4041
if tensor2 is None:
4142
tensor2 = tensor1
4243

0 commit comments

Comments (0)