-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_pipeline.py
More file actions
114 lines (95 loc) · 3.26 KB
/
data_pipeline.py
File metadata and controls
114 lines (95 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Dict, Optional, Tuple

import torch

from nexus.data.augmentation import AugmentationPipeline, MixupAugmentation
from nexus.data.cache import DataCache
from nexus.data.inputs import InputProcessor, InputConfig
from nexus.data.streaming import StreamingDataset
# Setup input processing.
# NOTE(review): InputConfig, InputProcessor, AugmentationPipeline,
# MixupAugmentation and DataCache are project types (nexus.data.*); their
# exact semantics are inferred from how they are used below — confirm
# against their definitions.
input_config = InputConfig(
    input_type="image",     # this pipeline handles image batches
    image_size=(224, 224),  # target spatial size (H, W)
    normalize=True,         # presumably per-channel normalization — TODO confirm in InputProcessor
    augment=True            # gates the augmentation branch in data_generator
)

# Turns raw samples into processed tensor dicts (see data_generator).
processor = InputProcessor(input_config)

# Image augmentation applied to processed["image"] when augment is enabled.
augmentation = AugmentationPipeline(
    image_size=(224, 224),
    augmentation_strength=0.8  # assumed to be a 0..1 intensity knob — TODO confirm
)

# Mixup used inside train_step; alpha presumably parameterizes the Beta
# distribution the mixing coefficient lam is drawn from — confirm.
mixup = MixupAugmentation(alpha=0.2)

# Cache that data_generator writes each processed batch into.
cache = DataCache()
# Create data pipeline
def data_generator(batch_size: int = 32):
    """Infinite generator yielding processed (and optionally augmented) batches.

    Args:
        batch_size: Number of samples fetched per batch via get_next_batch.

    Yields:
        The dict produced by processor.process for each batch, with "image"
        replaced by its augmented version when input_config.augment is set.
        Each yielded batch is also written to the cache.

    Notes:
        Fetch/processing errors are reported and the batch is skipped; the
        generator never terminates on its own.
    """
    while True:
        try:
            # Keep the try body limited to the steps that can legitimately
            # fail while producing a batch.
            data = get_next_batch(batch_size)
            processed = processor.process(data)

            if input_config.augment:
                processed["image"] = augmentation(processed["image"])

            # NOTE(review): _get_cache_key is a private DataCache method;
            # switch to a public key API if DataCache exposes one.
            cache.save(cache._get_cache_key(processed), processed)
        except Exception as e:
            # Best-effort stream: report and skip the failing batch rather
            # than killing the pipeline. A persistent failure will spin here.
            print(f"Error in data generation: {str(e)}")
            continue

        # Yield OUTSIDE the try/except: exceptions raised at the yield point
        # (e.g. injected by the consumer via gen.throw) must propagate to the
        # caller instead of being swallowed by the handler above — the
        # original wrapped the yield in the try and silently ate them.
        yield processed
# Create streaming dataset.
# Wraps the infinite batch generator in a StreamingDataset (project type);
# buffer_size presumably bounds an internal prefetch buffer of batches —
# confirm against StreamingDataset.
dataset = StreamingDataset(
    data_generator(),  # uses the default batch_size of 32
    buffer_size=1000
)
def train_step(batch: Dict[str, torch.Tensor], training: bool = True) -> Dict[str, torch.Tensor]:
    """Run one data-side training step, applying mixup when training.

    Args:
        batch: Dictionary containing "image" and "labels" tensors.
        training: Whether in training mode; mixup only runs when True.

    Returns:
        In training mode, a dict with the mixed images, both original label
        sets ("labels_a"/"labels_b") and the mixing coefficient "lam" so the
        loss can be interpolated later. Otherwise the batch, unchanged.
    """
    # Guard clause: evaluation path passes the batch through untouched.
    if not training:
        return batch

    # Mixup blends pairs of samples from the batch; lam is the blend weight.
    mixed_images, labels_a, labels_b, lam = mixup(batch["image"], batch["labels"])
    return {
        "image": mixed_images,
        "labels_a": labels_a,
        "labels_b": labels_b,
        "lam": lam,
    }
def get_next_batch(
    batch_size: int = 32,
    *,
    num_classes: int = 1000,
    image_size: Tuple[int, int] = (224, 224),
    channels: int = 3,
) -> Dict[str, torch.Tensor]:
    """
    Fetches the next batch of data for processing.

    This is a placeholder implementation that simulates a data source with
    random tensors. A real application would load samples from a dataset,
    convert them to tensors, and apply basic preprocessing here.

    Args:
        batch_size: Number of samples to fetch in this batch.
        num_classes: Number of label classes; labels are drawn uniformly
            from [0, num_classes).
        image_size: Spatial (height, width) of the generated images.
        channels: Number of image channels (3 = RGB).

    Returns:
        Dictionary containing:
            - "image": Tensor of shape (batch_size, channels, height, width)
            - "labels": Tensor of shape (batch_size,) containing class labels
    """
    height, width = image_size
    images = torch.randn(batch_size, channels, height, width)  # simulated image data
    # torch.randint's upper bound is exclusive, so labels lie in [0, num_classes).
    labels = torch.randint(0, num_classes, (batch_size,))
    return {
        "image": images,
        "labels": labels,
    }
# Example usage
if __name__ == "__main__":
    # The streaming dataset is backed by an infinite generator, so the demo
    # loop must be bounded explicitly — the original `for batch in dataset:`
    # never terminated.
    max_demo_batches = 10
    for step, batch in enumerate(dataset):
        if step >= max_demo_batches:
            break
        # Process batch with mixup during training
        processed_batch = train_step(batch, training=True)
        # Use processed batch for model training
        # Your training code here...