# data_utils.py
import os
import random

import paddle
import pandas as pd
from PIL import Image
| 9 | + |
def device2str(type=None, index=None, *, device=None):
    """Normalize a torch-style device spec into a paddle device string.

    Accepts an int ordinal (``0`` -> ``"gpu:0"``), a string (``"cuda"`` /
    ``"cuda:1"`` are translated to ``"gpu"`` / ``"gpu:1"``, anything
    containing ``"cpu"`` collapses to ``"cpu"``), a ``paddle.CPUPlace`` /
    ``paddle.CUDAPlace`` instance, or ``None`` (treated as ``"cpu"``).

    Args:
        type: device spec (int, str, paddle Place, or None).
        index: optional device ordinal, appended only when ``type`` is a
            string that does not already carry one.
        device: keyword-only override; when given it takes precedence
            over ``type`` (torch-like ``device=`` calling convention).

    Returns:
        A paddle device string such as ``"cpu"``, ``"gpu"`` or ``"gpu:0"``.
    """
    type = device if device else type
    if isinstance(type, int):
        # A bare ordinal means a GPU device index.
        type = f"gpu:{type}"
    elif isinstance(type, str):
        # Translate torch's "cuda" naming to paddle's "gpu".
        if "cuda" in type:
            type = type.replace("cuda", "gpu")
        if "cpu" in type:
            type = "cpu"
        elif index is not None and ":" not in type:
            # Only append the index when the string does not already
            # carry one (avoids malformed results like "gpu:0:1").
            type = f"{type}:{index}"
    elif (type is None) or isinstance(type, paddle.CPUPlace):
        # Check None first so the common default short-circuits before
        # any paddle attribute access.
        type = "cpu"
    elif isinstance(type, paddle.CUDAPlace):
        type = f"gpu:{type.get_device_id()}"
    return type


class CustomDataset(paddle.io.Dataset):
    """In-memory dataset of (image, group, features) triples.

    All tensors are moved onto the target device once, at construction
    time, so __getitem__ is a plain list lookup.
    """

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device
        self.preload_to_device()

    def preload_to_device(self):
        """Eagerly move every sample's tensors to ``self.device``.

        Features are converted to float32 tensors; the group key is left
        untouched (it is plain Python data, not a tensor).
        """
        moved = []
        for image, group, features in self.data:
            feature_tensor = paddle.to_tensor(data=features).astype(dtype="float32")
            moved.append((image.to(self.device), group, feature_tensor.to(self.device)))
        self.data = moved

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Samples are already device-resident tuples; return as-is.
        return self.data[index]


# Standard ImageNet-style preprocessing: center-crop to 224x224, convert to
# a CHW float tensor, then normalize with the usual ImageNet channel stats.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

image_transforms = paddle.vision.transforms.Compose(
    transforms=[
        paddle.vision.transforms.CenterCrop(size=224),
        paddle.vision.transforms.ToTensor(),
        paddle.vision.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


def _find_folder_csv(folder_path):
    """Return the first readable *.csv in folder_path as a DataFrame.

    Unreadable CSVs are reported and skipped; returns None when no CSV
    could be loaded. May raise if folder_path itself cannot be listed.
    """
    for file_name in os.listdir(folder_path):
        if file_name.endswith(".csv"):
            csv_path = os.path.join(folder_path, file_name)
            try:
                return pd.read_csv(csv_path)
            except Exception as e:
                print(f"Error reading CSV file {csv_path}: {str(e)}")
    return None


def _sorted_jpg_names(folder_path):
    """Return the folder's *.jpg filenames, sorted. May raise on listing errors."""
    return sorted(
        name for name in os.listdir(folder_path) if name.endswith(".jpg")
    )


def make_dataset(data_folder, N=1, verbose=False, device="cpu"):
    """Build a CustomDataset from the sample subfolders of ``data_folder``.

    Subfolders whose names contain at least three "_"-separated fields are
    sorted by the numeric third-from-last field, split into 5 contiguous
    groups, and then drawn round-robin (one random folder per group per
    pass) so the resulting dataset interleaves all 5 groups. From each
    folder, every N-th .jpg is loaded, transformed with
    ``image_transforms`` and paired with its "UTS (MPa)" value from the
    folder's CSV file (None when no CSV / no matching row).

    Args:
        data_folder: root directory containing the sample subfolders.
        N: sampling stride — keep every N-th image per folder (N >= 1).
        verbose: print grouping and per-folder sampling diagnostics.
        device: device string the resulting tensors are moved to.

    Returns:
        CustomDataset of (image_tensor, (group_idx, sample_idx), features).
    """
    random.seed(16)  # fixed seed -> reproducible folder sampling order
    this_data = []
    all_subfolders = [
        f
        for f in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, f)) and len(f.split("_")) >= 3
    ]

    def safe_folder_sort_key(x):
        # Sort by the numeric third-from-last "_" field; folders whose
        # field does not parse sort last instead of crashing.
        parts = x.split("_")
        try:
            return float(parts[-3])
        except Exception:
            return float("inf")

    subfolders = sorted(all_subfolders, key=safe_folder_sort_key)
    # Split the sorted folders into 5 contiguous groups. max(1, ...)
    # guards the ZeroDivisionError the original hit when there were
    # fewer than 5 subfolders.
    group_size = max(1, len(subfolders) // 5)
    grouped_subfolders = [[] for _ in range(5)]
    for i, subfolder in enumerate(subfolders):
        grouped_subfolders[min(i // group_size, 4)].append(subfolder)
    if verbose:
        print("分组结果:", grouped_subfolders)
    # Map each folder to its group index and to its rank in sorted order.
    chunk_keys = {}
    for i, gs in enumerate(grouped_subfolders):
        for sf in gs:
            chunk_keys[sf] = i
    sample_keys = {k: i for i, k in enumerate(subfolders)}
    # Round-robin: each pass removes one random folder from every
    # non-empty group, interleaving samples from all 5 groups.
    for _ in range(len(subfolders) // 5 + 1):
        for group in grouped_subfolders:
            if not group:
                continue
            selected_subfolder = random.choice(group)
            group.remove(selected_subfolder)
            folder_path = os.path.join(data_folder, selected_subfolder)
            if not os.path.isdir(folder_path):
                print(f"Warning: {folder_path} is not a valid directory")
                continue
            try:
                csv_data = _find_folder_csv(folder_path)
            except Exception as e:
                print(f"Error accessing directory {folder_path}: {str(e)}")
                continue
            try:
                image_names = _sorted_jpg_names(folder_path)
            except Exception as e:
                print(f"Error reading images from {folder_path}: {str(e)}")
                continue
            num = 0
            for i, image_name in enumerate(image_names):
                if i % N != 0:
                    continue  # keep every N-th image only
                num += 1
                image_path = os.path.join(folder_path, image_name)
                # Context manager closes the underlying file handle
                # (bare Image.open leaked it in the original).
                with Image.open(image_path) as img:
                    image_data = image_transforms(img.convert("RGB"))
                if csv_data is not None:
                    # Look up this image's target value; fall back to None
                    # instead of IndexError when the CSV has no matching row.
                    matches = csv_data.loc[
                        csv_data["Image Name"] == image_name, "UTS (MPa)"
                    ].values
                    image_features = matches[0].astype(float) if len(matches) else None
                else:
                    image_features = None
                this_data.append(
                    (
                        image_data,
                        (
                            chunk_keys[selected_subfolder],
                            sample_keys[selected_subfolder],
                        ),
                        image_features,
                    )
                )
            if verbose:
                print(f"文件夹 {selected_subfolder} 采样图片数: {num}")
    return CustomDataset(this_data, device=device)