diff --git a/data.py b/data.py
index 2cca6cf..0d8aab7 100644
--- a/data.py
+++ b/data.py
@@ -34,19 +34,81 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
     def decode(self, t: List[int]) -> str:
         return self.tokenizer.decode(t)
 
-class SFTData(Dataset):
-    def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False, seed=0, category="", K=4, dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
+class BaseDataset(Dataset):
+    def __init__(self, tokenizer=None, max_len=2048, test=False, category="", dedup=False, seed=None):
+        super().__init__()
+        self.data = None
+        self.inputs = None
+
+        if tokenizer is not None:
+            self.tokenizer = Tokenizer(tokenizer)
+        if seed is not None:
+            random.seed(seed)
 
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.tokenizer = Tokenizer(tokenizer)
         self.test = test
         self.max_len = max_len
         self.category = category
-        # self.K = K
         self.dedup = dedup
+
+    def __len__(self):
+        return len(self.data)
+
+    def get_inputs(self):
+        inputs = []
+        for i in tqdm(range(len(self.data))):
+            result = self.pre(i)
+            if result is not None:  # Skip None results from deduplication
+                inputs.append(result)
+        self.inputs = inputs
+
+    def get_all(self):
+        temp = []
+        for i in range(len(self.data)):
+            # self.data may be a DataFrame or a plain list, depending on subclass
+            row = self.data.iloc[i] if hasattr(self.data, "iloc") else self.data[i]
+            temp.append(self.get_history(row))
+        return temp
+
+    def get_inputs_list(self):
+        if self.inputs is not None:
+            return self.inputs
+        return [self.pre(i) for i in range(len(self.data))]
+
+    def __getitem__(self, idx):
+        if self.inputs is not None:
+            return self.inputs[idx]
+        return self.pre(idx)
+
+    def pre(self, idx):
+        raise NotImplementedError
+
+    def get_history(self, row):
+        raise NotImplementedError
+
+    def generate_prompt(self, data_point):
+        return f"""### User Input:
+{data_point["input"]}
+
+### Response:\n{data_point["output"]}"""
+
+
+class CSVBaseDataset(BaseDataset):
+    def __init__(self, train_file, sample=-1, seed=0, max_len=2048, category="", dedup=False, tokenizer=None, test=False):
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
+
+        self.data = pd.read_csv(train_file)
+
+        if sample > 0:
+            self.data = self.data.sample(sample, random_state=seed)
+
+
+class JSONBaseDataset(BaseDataset):
+    def __init__(self, item_file=None, index_file=None, tokenizer=None, max_len=2048, test=False, category="", dedup=False, seed=None):
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
+
+        # Load item features and indices if files are provided
+        if item_file is not None:
+            with open(item_file, 'r') as f:
+                self.item_feat = json.load(f)
+        if index_file is not None:
+            with open(index_file, 'r') as f:
+                self.indices = json.load(f)
+
+
+class SFTData(CSVBaseDataset):
+    def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False, seed=0, category="", K=4, dedup=False):
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer, test)
+
         self.instructs = [
             f"Given a list of {category} the user recetenly enjoy, please write a new {category} that the user may bought",
             f"Considering the {category} that has recently captured the user's interest, kindly create a compilation of other {category} that the user might have played prior to this.",
@@ -61,8 +123,6 @@ def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False,
             f"In relation to the user's recent entertainment with a given {category}, it would be appreciated if you could curate a list of {category} that might form part of the user's previous gaming history."
         ]
         self.get_inputs()
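+        # Length checks, indexing, and the get_inputs()/get_all()/get_inputs_list()
+        # helpers are inherited from BaseDataset above.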
-    def __len__(self):
-        return len(self.data)
 
     def generate_example_prompt(self, data_point):
@@ -71,14 +131,6 @@ def generate_example_prompt(self, data_point):
 
 ### Response:\n{data_point["output"]}
 """
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
-
     def get_history(self, row):
         row['history_item_title'] = eval(row['history_item_title'])
@@ -145,41 +197,12 @@ def pre(self, idx):
             "labels": labels[-self.max_len:],
         }
-
-
-
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-            # print(inputs[-1])
-
-        self.inputs = inputs
-
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
 
-class D3Dataset(Dataset):
+class D3Dataset(CSVBaseDataset):
     def __init__(self, train_file, max_len=2048, sample=-1, seed=0, category="", dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.category = category
-        self.dedup = dedup
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer=None, test=False)
+
         self.prompt2history = {}
         self.history2target = {}
         self.instructs = [
@@ -195,21 +218,7 @@ def __init__(self, train_file, max_len=2048, sample=-1, seed=0, category="", ded
             f"Bearing in mind the {category} that the user has recently been enthralled by, please construct a catalog of other {category} that the user potentially partook in beforehand.",
             f"In relation to the user's recent entertainment with a given {category}, it would be appreciated if you could curate a list of {category} that might form part of the user's previous gaming history."
         ]
-        self.get_inputs()
-
-
-    def __len__(self):
-        return len(self.data)
-
-
-
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
+        self.get_inputs()
 
     def get_history(self, row):
         row['history_item_title'] = eval(row['history_item_title'])
@@ -248,41 +257,13 @@ def pre(self, idx):
             "prompt": instruction + prompt,
             "completion": target_item,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
 
-class EvalD3Dataset(Dataset):
+class EvalD3Dataset(CSVBaseDataset):
     def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False, seed=0, category="", K=4, dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer, test)
+
         self.instructs = [
             f"Given a list of {category} the user recetenly enjoy, please write a new {category} that the user may bought",
             f"Considering the {category} that has recently captured the user's interest, kindly create a compilation of other {category} that the user might have played prior to this.",
@@ -297,9 +278,6 @@ def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False,
             f"In relation to the user's recent entertainment with a given {category}, it would be appreciated if you could curate a list of {category} that might form part of the user's previous gaming history."
         ]
         self.get_inputs()
 
-    def __len__(self):
-        return len(self.data)
-
     def generate_example_prompt(self, data_point):
         return f"""### Example {data_point["idx"]}:
@@ -307,14 +285,6 @@ def generate_example_prompt(self, data_point):
 
 ### Response:\n{data_point["output"]}
 """
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
-
     def get_history(self, row):
         row['history_item_title'] = eval(row['history_item_title'])
         L = len(row['history_item_title'])
@@ -379,53 +349,15 @@ def pre(self, idx):
             "labels": labels[-self.max_len:],
         }
-
-
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
-
-class SidDataset(Dataset):
+class SidDataset(CSVBaseDataset):
     def __init__(self, train_file, max_len=2048, sample=-1, seed=0, category="", dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.category = category
-        self.dedup = dedup
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer=None, test=False)
+
         self.prompt2history = {}
         self.history2target = {}
         self.get_inputs()
-
-
-    def __len__(self):
-        return len(self.data)
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
 
     def get_history(self, row):
         row['history_item_sid'] = eval(row['history_item_sid'])
@@ -460,48 +392,13 @@ def pre(self, idx):
             "completion": target_item,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-    def __getitem__(self, idx):
-        return self.inputs[idx]
 
-class SidSFTDataset(Dataset):
+class SidSFTDataset(CSVBaseDataset):
     def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", K=4, dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-        self.get_inputs()
-
-    def __len__(self):
-        return len(self.data)
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer, test)
 
-### Response:\n{data_point["output"]}"""
+        self.get_inputs()
 
     def get_history(self, row):
         row['history_item_sid'] = eval(row['history_item_sid'])
@@ -567,40 +464,12 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
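+# Illustrative sketch (hypothetical, not part of this change): a new dataset
+# built on CSVBaseDataset only needs to override pre(); length, indexing and
+# input caching come from BaseDataset, e.g.
+#
+#   class MyTask(CSVBaseDataset):
+#       def pre(self, idx):
+#           row = self.data.iloc[idx]
+#           return {"prompt": str(row["history_item_title"]),
+#                   "completion": row["item_title"]}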
-class SidSFTDataset_GPR(Dataset):
+class SidSFTDataset_GPR(CSVBaseDataset):
     def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", K=4, dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer, test)
+
         # Try to load features from standard location
         try:
             with open(f'data/{category}/{category}.user.json', 'r') as f:
@@ -621,15 +490,6 @@ def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test=False, s
             self.item_features = {}
 
         self.get_inputs()
-
-    def __len__(self):
-        return len(self.data)
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
 
     def get_history(self, row):
         row['history_item_sid'] = eval(row['history_item_sid'])
@@ -732,44 +592,14 @@ def pre(self, idx):
             "labels": labels[-self.max_len:],
             "final_value": final_value
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
        -        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
 
-class EvalSidDataset(Dataset):
+class EvalSidDataset(CSVBaseDataset):
     def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False, seed=0, category="", K=4, dedup=False):
-        self.data = pd.read_csv(train_file)
-        random.seed(seed)
-
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-        self.get_inputs()
-    def __len__(self):
-        return len(self.data)
+        super().__init__(train_file, sample, seed, max_len, category, dedup, tokenizer, test)
 
+        self.get_inputs()
 
     def generate_example_prompt(self, data_point):
         return f"""### Example {data_point["idx"]}:
@@ -777,12 +607,6 @@ def generate_example_prompt(self, data_point):
 
 ### Response:\n{data_point["output"]}
 """
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
 
     def get_history(self, row):
         row['history_item_sid'] = eval(row['history_item_sid'])
@@ -848,32 +672,10 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
-            }
-
-
-
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-
-        self.inputs = inputs
-
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
+            }
 
-    def __getitem__(self, idx):
-        return self.inputs[idx]
-class SidItemFeatDataset(Dataset):
+class SidItemFeatDataset(JSONBaseDataset):
     def __init__(self, item_file, index_file, tokenizer=None, max_len=2048, sample=-1, test=False, seed=0, category=""):
         """
         Dataset for sid2title and title2sid tasks.
@@ -887,19 +689,8 @@ def __init__(self, item_file, index_file, tokenizer=None, max_len=2048, sample=-
         test: Whether this is test mode
         seed: Random seed
         category: Category name for prompts
-        """
-        random.seed(seed)
-
-        # Load item features and indices
-        with open(item_file, 'r') as f:
-            self.item_feat = json.load(f)
-        with open(index_file, 'r') as f:
-            self.indices = json.load(f)
-
-        self.tokenizer = Tokenizer(tokenizer) if tokenizer is not None else None
-        self.test = test
-        self.max_len = max_len
-        self.category = category
+        """
+        super().__init__(item_file=item_file, index_file=index_file, tokenizer=tokenizer, max_len=max_len, test=test, category=category, dedup=False, seed=seed)
 
         # Build sid2title and title2sid mappings
         self.sid2title = {}
@@ -938,10 +729,7 @@ def __init__(self, item_file, index_file, tokenizer=None, max_len=2048, sample=-
 
         if self.tokenizer is not None:
             self.get_inputs()
-
-    def __len__(self):
-        return len(self.data)
-
+
     def generate_prompt(self, data_point):
         if data_point['task'] == 'title2sid':
             prompt = f"Which item has the title: {data_point['input']}?"
@@ -996,23 +784,9 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-        self.inputs = inputs
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else [self.pre(i) for i in range(len(self))]
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        return self.pre(idx)
 
-class RLTitle2SidDataset(Dataset):
+class RLTitle2SidDataset(JSONBaseDataset):
     def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=False):
         """
         RL-specific dataset for title2sid and description2sid tasks.
@@ -1027,16 +801,9 @@ def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
-        # Load item features and indices
-        with open(item_file, 'r') as f:
-            self.item_feat = json.load(f)
-        with open(index_file, 'r') as f:
-            self.indices = json.load(f)
+        super().__init__(item_file, index_file, tokenizer=None, max_len=1024, test=False, category=category, dedup=dedup, seed=seed)
+
-        self.category = category
-        self.dedup = dedup
 
         self.prompt2history = {}
         self.history2target = {}
@@ -1090,9 +857,7 @@ def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=
             self.data = random.sample(self.data, sample)
 
         self.get_inputs()
-
-    def __len__(self):
-        return len(self.data)
+
 
     def generate_prompt(self, data_point):
         if data_point['task'] == 'title2sid':
@@ -1120,27 +885,9 @@ def pre(self, idx):
             "completion": target_item,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.data[i])
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
 
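+# Like RLTitle2SidDataset above, the RL datasets that follow are constructed
+# with tokenizer=None: pre() emits raw {"prompt": ..., "completion": ...}
+# pairs rather than token ids.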
-class RLSeqTitle2SidDataset(Dataset):
+class RLSeqTitle2SidDataset(CSVBaseDataset):
     def __init__(self, train_file, sample=-1, seed=0, category="", dedup=False):
         """
         RL-specific dataset for sequential recommendation using title sequences.
@@ -1153,23 +900,13 @@ def __init__(self, train_file, sample=-1, seed=0, category="", dedup=False):
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
-        # Load sequence data
-        self.data = pd.read_csv(train_file)
-        if sample > 0:
-            self.data = self.data.sample(sample, random_state=seed)
-
-        self.category = category
-        self.dedup = dedup
+        super().__init__(train_file, sample, seed, max_len=1024, category=category, dedup=dedup, tokenizer=None, test=False)
+
         self.prompt2history = {}
         self.history2target = {}
 
         self.get_inputs()
 
-    def __len__(self):
-        return len(self.data)
-
     def generate_prompt(self, inter_titles):
         return f"Given the title sequence of user historical interactive items: {inter_titles}, can you recommend a suitable next item for the user?"
 
@@ -1227,32 +964,9 @@ def pre(self, idx):
             "completion": target,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from deduplication
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        result = self.pre(idx)
-        return result if result is not None else {"prompt": "", "completion": ""}
-
+
-class RLSid2TitleDataset(Dataset):
+class RLSid2TitleDataset(JSONBaseDataset):
     def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=False):
         """
         RL-specific dataset for sid2title tasks.
@@ -1266,16 +980,8 @@ def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=
         category: Category name for prompts
         dedup: Whether to filter duplicate items
        """
-        random.seed(seed)
-
-        # Load item features and indices
-        with open(item_file, 'r') as f:
-            self.item_feat = json.load(f)
-        with open(index_file, 'r') as f:
-            self.indices = json.load(f)
-
-        self.category = category
-        self.dedup = dedup
+        super().__init__(item_file, index_file, tokenizer=None, max_len=1024, test=False, category=category, dedup=dedup, seed=seed)
+
 
         self.prompt2history = {}
         self.history2target = {}
@@ -1306,9 +1012,6 @@ def __init__(self, item_file, index_file, sample=-1, seed=0, category="", dedup=
 
         self.get_inputs()
 
-    def __len__(self):
-        return len(self.data)
-
     def generate_prompt(self, data_point):
         prompt = f'What is the title of item "{data_point["input"]}"?'
         response = data_point['output']
@@ -1331,27 +1034,9 @@ def pre(self, idx):
             "completion": target_item,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            inputs.append(self.pre(i))
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.data[i])
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs
-
-    def __getitem__(self, idx):
-        return self.inputs[idx]
-
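+# The remaining datasets combine a CSV interaction log with JSON item metadata,
+# so they subclass BaseDataset directly and load both sources in __init__.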
-class RLSidhis2TitleDataset(Dataset):
+
+class RLSidhis2TitleDataset(BaseDataset):
     def __init__(self, train_file, item_file, index_file, sample=-1, seed=0, category="", dedup=False):
         """
         RL-specific dataset for sequential recommendation using semantic IDs in history and outputting item titles.
@@ -1366,21 +1051,19 @@ def __init__(self, train_file, item_file, index_file, sample=-1, seed=0, categor
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
-        # Load sequence data
+        super().__init__(tokenizer=None, max_len=1024, test=False, category=category, dedup=dedup, seed=seed)
+
+        # Initialize CSV part
         self.data = pd.read_csv(train_file)
         if sample > 0:
             self.data = self.data.sample(sample, random_state=seed)
 
-        # Load item features and indices
+        # Initialize JSON part
         with open(item_file, 'r') as f:
             self.item_feat = json.load(f)
         with open(index_file, 'r') as f:
             self.indices = json.load(f)
-
-        self.category = category
-        self.dedup = dedup
+
         self.prompt2history = {}
         self.history2target = {}
@@ -1390,15 +1073,6 @@ def __init__(self, train_file, item_file, index_file, sample=-1, seed=0, categor
             self.id2title[item_id] = features['title']
 
         self.get_inputs()
-
-    def __len__(self):
-        return len(self.data)
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
 
     def get_history(self, row):
         row['history_item_sid'] = eval(row['history_item_sid'])
@@ -1447,32 +1121,9 @@ def pre(self, idx):
             "completion": target_item,
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from deduplication
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
 
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        result = self.pre(idx)
-        return result if result is not None else {"prompt": "", "completion": ""}
-
-class FusionSeqRecDataset(Dataset):
+class FusionSeqRecDataset(BaseDataset):
     def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", dedup=False):
         """
         Fusion dataset combining sequence recommendation with item features.
@@ -1490,25 +1141,19 @@ def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, s
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
 
-        # Load sequence data
+        # Initialize CSV part
         self.data = pd.read_csv(train_file)
         if sample > 0:
             self.data = self.data.sample(sample, random_state=seed)
 
-        # Load item features and indices
+        # Initialize JSON part
         with open(item_file, 'r') as f:
             self.item_feat = json.load(f)
         with open(index_file, 'r') as f:
             self.indices = json.load(f)
-
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-
+
         # Build sid2title and sid2description mappings
         self.sid2title = {}
         self.sid2description = {}
@@ -1532,7 +1177,7 @@ def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, s
         # print("self.sid2title: ", self.sid2title)
         # print("self.sid2description: ", self.sid2description)
         self.get_inputs()
-    
+
     def _process_description(self, description, title):
         """
         Process description according to the requirements:
@@ -1581,9 +1226,6 @@ def _process_description(self, description, title):
                 # Empty list, use title
                 return title
 
-    def __len__(self):
-        return len(self.data)
-
     def generate_prompt_title(self, history):
         return f"The user has sequentially interacted with items {history}. Can you recommend the next item for him? Tell me the title of the item"
 
@@ -1685,26 +1327,9 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from deduplication
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        return self.pre(idx)
-
-class TitleHistory2SidSFTDataset(Dataset):
+class TitleHistory2SidSFTDataset(BaseDataset):
     def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", dedup=False):
         """
         SFT dataset that uses item titles in user history to predict next item's semantic ID.
@@ -1721,24 +1346,18 @@ def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, s
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
-        # Load sequence data
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
+        # Initialize CSV part
         self.data = pd.read_csv(train_file)
         if sample > 0:
             self.data = self.data.sample(sample, random_state=seed)
 
-        # Load item features and indices
+        # Initialize JSON part
        with open(item_file, 'r') as f:
             self.item_feat = json.load(f)
         with open(index_file, 'r') as f:
             self.indices = json.load(f)
-
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
+
         # Build item_id to semantic ID mapping
         self.id2sid = {}
@@ -1749,15 +1368,6 @@ def __init__(self, train_file, item_file, index_file, tokenizer, max_len=2048, s
 
         self.get_inputs()
 
-    def __len__(self):
-        return len(self.data)
-
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
     def get_history(self, row):
         """Extract user history from title sequence and target semantic ID"""
         # Parse history_item_title field
@@ -1833,32 +1443,9 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from deduplication
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.data)):
-            temp.append(self.get_history(self.data.iloc[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        result = self.pre(idx)
-        return result if result is not None else {"input_ids": [], "attention_mask": [], "labels": []}
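+# The preference datasets keep their matched examples in self.data (formerly
+# self.matched_data) so that BaseDataset.__len__/get_inputs apply unchanged.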
-class PreferenceSFTDataset(Dataset):
+class PreferenceSFTDataset(BaseDataset):
     def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", dedup=False):
         """
         SFT dataset that uses user interaction history and preferences from preference file.
@@ -1874,8 +1461,7 @@ def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sa
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
 
         # Load user preferences - handle both JSON and JSONL formats
         with open(user_preference_file, 'r') as f:
             try:
@@ -1914,17 +1500,11 @@ def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sa
         with open(index_file, 'r') as f:
             self.indices = json.load(f)
 
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-
         # Find users with preferences and prepare data
-        self.matched_data = self._prepare_preference_data()
+        self.data = self._prepare_preference_data()
 
-        if sample > 0 and sample < len(self.matched_data):
-            self.matched_data = random.sample(self.matched_data, sample)
+        if sample > 0 and sample < len(self.data):
+            self.data = random.sample(self.data, sample)
 
         self.get_inputs()
 
@@ -1955,9 +1535,6 @@ def _prepare_preference_data(self):
 
         return matched_data
 
-    def __len__(self):
-        return len(self.matched_data)
-
     def _convert_to_semantic_ids(self, item_ids):
         """Convert item IDs to semantic ID format using index.json"""
         semantic_ids = []
@@ -1977,7 +1554,7 @@ def _convert_to_semantic_ids(self, item_ids):
 
         return semantic_ids
 
-    def get_history_and_preference(self, row_data):
+    def get_history(self, row_data):
         """Extract and format user history and preference"""
         # Get input history item IDs (all but last item) and convert to semantic IDs
         input_history_ids = row_data['input_history']
@@ -2004,12 +1581,6 @@ def get_history(self, row_data):
         # print(result)
         return result
 
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
     def pre(self, idx):
         instruction = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -2019,8 +1590,8 @@ def pre(self, idx):
 """
         tokens = self.tokenizer.encode(instruction, bos=True, eos=False)
-        row_data = self.matched_data[idx]
-        history_and_pref = self.get_history_and_preference(row_data)
+        row_data = self.data[idx]
+        history_and_pref = self.get_history(row_data)
 
         # Skip empty histories or missing targets
         if not history_and_pref['history_semantic_ids'] or not history_and_pref['target_sid']:
@@ -2053,32 +1624,9 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.matched_data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from empty histories
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.matched_data)):
-            temp.append(self.get_history_and_preference(self.matched_data[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        result = self.pre(idx)
-        return result if result is not None else {"input_ids": [], "attention_mask": [], "labels": []}
 
-class UserPreference2sidSFTDataset(Dataset):
+class UserPreference2sidSFTDataset(BaseDataset):
     def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sample=-1, test=False, seed=0, category="", dedup=False):
         """
         SFT dataset that uses user interaction history with preferences to predict next item's semantic ID.
@@ -2095,8 +1643,8 @@ def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sa
         category: Category name for prompts
         dedup: Whether to filter duplicate items
         """
-        random.seed(seed)
-
+        super().__init__(tokenizer, max_len, test, category, dedup, seed)
+
         # Load user preferences - handle both JSON and JSONL formats
         with open(user_preference_file, 'r') as f:
             try:
@@ -2135,17 +1683,11 @@ def __init__(self, user_preference_file, index_file, tokenizer, max_len=2048, sa
         with open(index_file, 'r') as f:
             self.indices = json.load(f)
 
-        self.tokenizer = Tokenizer(tokenizer)
-        self.test = test
-        self.max_len = max_len
-        self.category = category
-        self.dedup = dedup
-
         # Prepare training data from preference file interaction histories
-        self.matched_data = self._prepare_sequence_data()
+        self.data = self._prepare_sequence_data()
 
-        if sample > 0 and sample < len(self.matched_data):
-            self.matched_data = random.sample(self.matched_data, sample)
+        if sample > 0 and sample < len(self.data):
+            self.data = random.sample(self.data, sample)
 
         self.get_inputs()
 
@@ -2176,9 +1718,6 @@ def _prepare_sequence_data(self):
 
         return matched_data
 
-    def __len__(self):
-        return len(self.matched_data)
-
     def _convert_to_semantic_ids(self, item_ids):
         """Convert item IDs to semantic ID format using index.json"""
         semantic_ids = []
@@ -2223,12 +1762,6 @@ def get_input_and_target(self, row_data):
             "target_sid": target_sid
         }
 
-    def generate_prompt(self, data_point):
-        return f"""### User Input:
-{data_point["input"]}
-
-### Response:\n{data_point["output"]}"""
-
     def pre(self, idx):
         instruction = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -2238,7 +1771,7 @@ def pre(self, idx):
 """
         tokens = self.tokenizer.encode(instruction, bos=True, eos=False)
-        row_data = self.matched_data[idx]
+        row_data = self.data[idx]
         input_and_target = self.get_input_and_target(row_data)
 
         # Skip empty histories or missing targets
@@ -2272,26 +1805,3 @@ def pre(self, idx):
             "attention_mask": attention_mask[-self.max_len:],
             "labels": labels[-self.max_len:],
         }
-
-    def get_inputs(self):
-        inputs = []
-        for i in tqdm(range(len(self.matched_data))):
-            result = self.pre(i)
-            if result is not None:  # Skip None results from empty histories or missing targets
-                inputs.append(result)
-        self.inputs = inputs
-
-    def get_all(self):
-        temp = []
-        for i in range(len(self.matched_data)):
-            temp.append(self.get_input_and_target(self.matched_data[i]))
-        return temp
-
-    def get_inputs_list(self):
-        return self.inputs if hasattr(self, 'inputs') else []
-
-    def __getitem__(self, idx):
-        if hasattr(self, 'inputs'):
-            return self.inputs[idx]
-        result = self.pre(idx)
-        return result if result is not None else {"input_ids": [], "attention_mask": [], "labels": []}
diff --git a/data_test.py b/data_test.py
new file mode 100644
index 0000000..de6b713
--- /dev/null
+++ b/data_test.py
@@ -0,0 +1,299 @@
+import unittest
+import tempfile
+import os
+import pandas as pd
+import json
+import sys
+import random
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from data import (
+    SFTData, D3Dataset, EvalD3Dataset, EvalSidDataset,
+    SidDataset, SidSFTDataset, SidItemFeatDataset, RLTitle2SidDataset,
+    RLSid2TitleDataset, RLSidhis2TitleDataset, FusionSeqRecDataset,
+    TitleHistory2SidSFTDataset, PreferenceSFTDataset, UserPreference2sidSFTDataset
+)
+
+class MockTokenizer:
+    def __init__(self):
+        self.pad_token_id = 0
+        self.eos_token_id = 3
+        self.bos_token_id = 2
+
+    def encode(self, text, bos=False, eos=False):
+        # Simple mock encoding - just return list of integers based on text length
+        tokens = list(range(10, 10 + min(len(text), 50)))  # Limit to 50 tokens max
+        if bos:
+            tokens = [self.bos_token_id] + tokens
+        if eos:
+            tokens = tokens + [self.eos_token_id]
+        return tokens
+
+def create_minimal_csv(file_path, data):
+    """Helper to create minimal CSV files for testing"""
+    df = pd.DataFrame(data)
+    df.to_csv(file_path, index=False)
+
+class TestDataModule(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer = MockTokenizer()
+
+        # Create temporary directory for test files
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        cls.temp_path = cls.temp_dir.name
+
+        # Create sample CSV data
+        cls.csv_data = {
+            'history_item_title': ["['Item A', 'Item B']", "['Item C', 'Item D']"],
+            'item_title': ['Item E', 'Item F'],
+            'history_item_id': ["['1', '2']", "['3', '4']"],
+            'item_id': ['5', '6'],
+            'history_item_sid': ["['SID1', 'SID2']", "['SID3', 'SID4']"],
+            'item_sid': ['SID5', 'SID6'],
+            'user_id_original_str': ['user1', 'user2'],
+            'e_token': ['[CTX_HOMEPAGE]', '[CTX_SEARCH]']
+        }
+        cls.csv_file = os.path.join(cls.temp_path, 'test_data.csv')
+        create_minimal_csv(cls.csv_file, cls.csv_data)
+
+        # Create sample item features JSON
+        cls.item_features = {
+            '5': {'title': 'Item E', 'description': 'Description of Item E', 'item_type': 'O'},
+            '6': {'title': 'Item F', 'description': 'Description of Item F', 'item_type': 'I'}
+        }
+        cls.item_file = os.path.join(cls.temp_path, 'test.item.json')
+        with open(cls.item_file, 'w') as f:
+            json.dump(cls.item_features, f)
+
+        # Create sample indices JSON
+        cls.indices = {
+            '5': ['SID5_1', 'SID5_2', 'SID5_3'],
+            '6': ['SID6_1', 'SID6_2', 'SID6_3']
+        }
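+        # indices maps raw item_id -> list of semantic-id tokens (the index.json format)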
+        cls.index_file = os.path.join(cls.temp_path, 'test.index.json')
+        with open(cls.index_file, 'w') as f:
+            json.dump(cls.indices, f)
+
+        # Create sample user preference JSON
+        cls.user_preference_data = [
+            {
+                'user': 'user1',
+                'user_preference': 'Likes action games',
+                'context': {
+                    'history_items': ['1', '2'],
+                    'target_item': '5'
+                },
+                'split': 'train'
+            },
+            {
+                'user': 'user2',
+                'user_preference': 'Prefers strategy games',
+                'context': {
+                    'history_items': ['3', '4'],
+                    'target_item': '6'
+                },
+                'split': 'train'
+            }
+        ]
+        cls.preference_file = os.path.join(cls.temp_path, 'test_preference.json')
+        with open(cls.preference_file, 'w') as f:
+            json.dump(cls.user_preference_data, f)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Cleanup temporary directory
+        cls.temp_dir.cleanup()
+
+    def test_SFTData_initialization(self):
+        """Test SFTData initialization"""
+        dataset = SFTData(
+            train_file=self.csv_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_D3Dataset_initialization(self):
+        """Test D3Dataset initialization"""
+        dataset = D3Dataset(
+            train_file=self.csv_file,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_EvalD3Dataset_initialization(self):
+        """Test EvalD3Dataset initialization"""
+        dataset = EvalD3Dataset(
+            train_file=self.csv_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_SidDataset_initialization(self):
+        """Test SidDataset initialization"""
+        dataset = SidDataset(
+            train_file=self.csv_file,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_SidSFTDataset_initialization(self):
+        """Test SidSFTDataset initialization"""
+        dataset = SidSFTDataset(
+            train_file=self.csv_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_SidItemFeatDataset_initialization(self):
+        """Test SidItemFeatDataset initialization"""
+        dataset = SidItemFeatDataset(
+            item_file=self.item_file,
+            index_file=self.index_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=2,
+            seed=0
+        )
+        self.assertGreaterEqual(len(dataset), 2)  # Should have at least 2 samples (sid2title and title2sid)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_EvalSidDataset_initialization(self):
+        """Test EvalSidDataset initialization"""
+        dataset = EvalSidDataset(
+            train_file=self.csv_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0,
+            category="games"
+        )
+
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_RLTitle2SidDataset_initialization(self):
+        """Test RLTitle2SidDataset initialization"""
+        dataset = RLTitle2SidDataset(
+            item_file=self.item_file,
+            index_file=self.index_file,
+            sample=2,
+            seed=0
+        )
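+        # two fixture items, each can yield a title2sid and a description2sid example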
+        self.assertGreaterEqual(len(dataset), 2)  # Should have at least 2 samples
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_RLSid2TitleDataset_initialization(self):
+        """Test RLSid2TitleDataset initialization"""
+        dataset = RLSid2TitleDataset(
+            item_file=self.item_file,
+            index_file=self.index_file,
+            sample=2,
+            seed=0
+        )
+        self.assertGreaterEqual(len(dataset), 1)  # Should have at least 1 sample
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_RLSidhis2TitleDataset_initialization(self):
+        """Test RLSidhis2TitleDataset initialization"""
+        dataset = RLSidhis2TitleDataset(
+            train_file=self.csv_file,
+            item_file=self.item_file,
+            index_file=self.index_file,
+            sample=1,
+            seed=0
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_FusionSeqRecDataset_initialization(self):
+        """Test FusionSeqRecDataset initialization"""
+        dataset = FusionSeqRecDataset(
+            train_file=self.csv_file,
+            item_file=self.item_file,
+            index_file=self.index_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_TitleHistory2SidSFTDataset_initialization(self):
+        """Test TitleHistory2SidSFTDataset initialization"""
+        dataset = TitleHistory2SidSFTDataset(
+            train_file=self.csv_file,
+            item_file=self.item_file,
+            index_file=self.index_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_PreferenceSFTDataset_initialization(self):
+        """Test PreferenceSFTDataset initialization"""
+        dataset = PreferenceSFTDataset(
+            user_preference_file=self.preference_file,
+            index_file=self.index_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+    def test_UserPreference2sidSFTDataset_initialization(self):
+        """Test UserPreference2sidSFTDataset initialization"""
+        dataset = UserPreference2sidSFTDataset(
+            user_preference_file=self.preference_file,
+            index_file=self.index_file,
+            tokenizer=self.tokenizer,
+            max_len=128,
+            sample=1,
+            seed=0
+        )
+        self.assertEqual(len(dataset), 1)
+        self.assertTrue(hasattr(dataset, 'inputs'))
+
+if __name__ == '__main__':
+    # Run the tests
+    unittest.main()
\ No newline at end of file