Skip to content

Commit 34f4823

Browse files
committed
refactor: Move retry_on_exception decorator to subset selection utils
- Relocate the retry_on_exception decorator from subset_selection.py to subset_selection_utils.py
- Remove the now-unnecessary imports from subset_selection.py
- Simplify code structure and improve module organization
- Maintain the existing error handling and retry logic

Signed-off-by: eshwarprasadS <eshwarprasad.s01@gmail.com>
1 parent 5fb24ba commit 34f4823

File tree

3 files changed

+57
-54
lines changed

3 files changed

+57
-54
lines changed

src/instructlab/sdg/encoders/arctic_encoder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# Standard
44
from dataclasses import dataclass
55
from typing import Dict, List, Optional, TypedDict, Union
6-
import os
76
import logging
7+
import os
8+
89
# Third Party
910
from tqdm import tqdm
1011
from transformers import AutoModel, AutoTokenizer
@@ -110,7 +111,7 @@ def _prepare_inputs(
110111
if isinstance(texts, str):
111112
texts = [texts]
112113

113-
#Ensure we always have an instruction
114+
# Ensure we always have an instruction
114115
if not instruction and not self.cfg.use_default_instruction:
115116
raise ValueError(
116117
"An instruction must be provided when use_default_instruction is False. "
@@ -124,7 +125,7 @@ def _prepare_inputs(
124125
):
125126
instruction = str(self.cfg.model_config["default_instruction"])
126127

127-
if not instruction: #catch if default_instruction is empty
128+
if not instruction: # catch if default_instruction is empty
128129
raise ValueError(
129130
"No instruction available. Either provide an instruction or ensure "
130131
"the model config has a valid default_instruction."

src/instructlab/sdg/subset_selection.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Standard
22
from dataclasses import dataclass, field
3-
from functools import wraps
43
from multiprocessing import Pool
54
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
65
import gc
@@ -9,7 +8,6 @@
98
import math
109
import os
1110
import re
12-
import time
1311

1412
# Third Party
1513
from datasets import concatenate_datasets, load_dataset
@@ -21,10 +19,11 @@
2119
import torch
2220

2321
# Local
24-
# from .encoders.arctic_encoder import ArcticEmbedEncoder
25-
from .utils.subset_selection_utils import compute_pairwise_dense, get_default_num_gpus
26-
27-
__DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22+
from .utils.subset_selection_utils import (
23+
compute_pairwise_dense,
24+
get_default_num_gpus,
25+
retry_on_exception,
26+
)
2827

2928
# Type variables
3029
T = TypeVar("T")
@@ -130,51 +129,6 @@ def __post_init__(self):
130129
raise ValueError("Absolute values in subset_sizes must be positive")
131130

132131

133-
def retry_on_exception(func):
134-
"""
135-
Decorator to retry a function upon exception up to a maximum number of retries.
136-
"""
137-
138-
@wraps(func)
139-
def wrapper(self, *args, **kwargs):
140-
last_exception = None
141-
for attempt in range(self.config.system.max_retries):
142-
try:
143-
return func(self, *args, **kwargs)
144-
except torch.cuda.OutOfMemoryError as e:
145-
# Happens when GPU runs out of memory during batch processing
146-
last_exception = e
147-
logger.error(f"GPU out of memory on attempt {attempt + 1}: {str(e)}")
148-
except RuntimeError as e:
149-
# Common PyTorch errors (including some OOM errors and model issues)
150-
last_exception = e
151-
logger.error(
152-
f"PyTorch runtime error on attempt {attempt + 1}: {str(e)}"
153-
)
154-
except ValueError as e:
155-
# From tokenizer or input validation
156-
last_exception = e
157-
logger.error(f"Value error on attempt {attempt + 1}: {str(e)}")
158-
except TypeError as e:
159-
# From incorrect input types or model parameter mismatches
160-
last_exception = e
161-
logger.error(f"Type error on attempt {attempt + 1}: {str(e)}")
162-
except IndexError as e:
163-
# Possible during tensor operations or batch processing
164-
last_exception = e
165-
logger.error(f"Index error on attempt {attempt + 1}: {str(e)}")
166-
167-
if attempt < self.config.system.max_retries - 1:
168-
logger.info(f"Retrying in {self.config.system.retry_delay} seconds...")
169-
time.sleep(self.config.system.retry_delay)
170-
gc.collect()
171-
torch.cuda.empty_cache()
172-
173-
raise last_exception
174-
175-
return wrapper
176-
177-
178132
class DataProcessor:
179133
"""
180134
Enhanced data processor with support for combined files and multiple selection methods.

src/instructlab/sdg/utils/subset_selection_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Standard
2+
from functools import wraps
23
from typing import Optional, Union
4+
import gc
35
import logging
6+
import time
47

58
# Third Party
69
from torch import Tensor
@@ -14,6 +17,51 @@
1417
logger = logging.getLogger(__name__)
1518

1619

20+
def retry_on_exception(func):
21+
"""
22+
Decorator to retry a function upon exception up to a maximum number of retries.
23+
"""
24+
25+
@wraps(func)
26+
def wrapper(self, *args, **kwargs):
27+
last_exception = None
28+
for attempt in range(self.config.system.max_retries):
29+
try:
30+
return func(self, *args, **kwargs)
31+
except torch.cuda.OutOfMemoryError as e:
32+
# Happens when GPU runs out of memory during batch processing
33+
last_exception = e
34+
logger.error(f"GPU out of memory on attempt {attempt + 1}: {str(e)}")
35+
except RuntimeError as e:
36+
# Common PyTorch errors (including some OOM errors and model issues)
37+
last_exception = e
38+
logger.error(
39+
f"PyTorch runtime error on attempt {attempt + 1}: {str(e)}"
40+
)
41+
except ValueError as e:
42+
# From tokenizer or input validation
43+
last_exception = e
44+
logger.error(f"Value error on attempt {attempt + 1}: {str(e)}")
45+
except TypeError as e:
46+
# From incorrect input types or model parameter mismatches
47+
last_exception = e
48+
logger.error(f"Type error on attempt {attempt + 1}: {str(e)}")
49+
except IndexError as e:
50+
# Possible during tensor operations or batch processing
51+
last_exception = e
52+
logger.error(f"Index error on attempt {attempt + 1}: {str(e)}")
53+
54+
if attempt < self.config.system.max_retries - 1:
55+
logger.info(f"Retrying in {self.config.system.retry_delay} seconds...")
56+
time.sleep(self.config.system.retry_delay)
57+
gc.collect()
58+
torch.cuda.empty_cache()
59+
60+
raise last_exception
61+
62+
return wrapper
63+
64+
1765
def get_default_num_gpus() -> int:
1866
"""Get the default number of GPUs based on available CUDA devices."""
1967
if not torch.cuda.is_available():

0 commit comments

Comments (0)