
Commit e3df758

fixing style and quality tests
1 parent 87c73b4 commit e3df758

2 files changed (+20, -22 lines)


examples/orthogonal_subspace_learning/osf_continual_learning.py

Lines changed: 13 additions & 11 deletions
@@ -81,7 +81,7 @@ def compute_accuracy_scienceqa(model, eval_dataset, tokenizer, data_collator):
 
         # Extract the answer (last letter in the generated text)
         # Look for single capital letters A, B, C, D
-        matches = re.findall(r'\b([A-D])\b', generated_text)
+        matches = re.findall(r"\b([A-D])\b", generated_text)
         pred = matches[-1] if matches else "X"
 
         # Get ground truth (find the label that's not -100)
@@ -124,7 +124,7 @@ def compute_accuracy_numglue(model, eval_dataset, tokenizer, data_collator):
         generated_text = tokenizer.decode(outputs[i], skip_special_tokens=True)
 
         # Extract number from generated text
-        numbers = re.findall(r'-?\d+\.?\d*', generated_text)
+        numbers = re.findall(r"-?\d+\.?\d*", generated_text)
         pred = numbers[-1] if numbers else "-999"
 
         # Get ground truth
@@ -208,7 +208,7 @@ def evaluate_model(model, eval_dataset, data_collator, tokenizer, task_name, tas
     else:
         accuracy = 0.0
 
-    print(f" {task_name}: Loss = {loss:.4f}, Accuracy = {accuracy*100:.2f}%")
+    print(f" {task_name}: Loss = {loss:.4f}, Accuracy = {accuracy * 100:.2f}%")
     return loss, accuracy
 
 
@@ -321,7 +321,7 @@ def train_with_osf(
     for task_idx, task in enumerate(tasks):
         print(f"\n{'=' * 80}")
         print(f"TASK {task_idx + 1}: {task['name']}")
-        print(f"Effective Rank: {task['effective_rank']} (preserving {task['effective_rank']*100:.0f}%)")
+        print(f"Effective Rank: {task['effective_rank']} (preserving {task['effective_rank'] * 100:.0f}%)")
         print(f"{'=' * 80}")
 
         # Configure OSF for this task
@@ -368,7 +368,7 @@ def train_with_osf(
 
         # Unload OSF to get the updated base model for next task (if not last task)
         if task_idx < len(tasks) - 1:
-            print(f"\nUnloading OSF adapter to prepare for next task...")
+            print("\nUnloading OSF adapter to prepare for next task...")
             model = model.unload()
 
     # Save final model
@@ -547,14 +547,16 @@ def print_results_comparison(osf_history, full_history):
     osf_avg_final = sum(osf_final_accs) / len(osf_final_accs)
     full_avg_final = sum(full_final_accs) / len(full_final_accs)
 
-    print(f"\n1. Average Accuracy Across All 3 Tasks (After Final Task):")
+    print("\n1. Average Accuracy Across All 3 Tasks (After Final Task):")
     print(f" OSF: {osf_avg_final:.2f}%")
     print(f" Full FT: {full_avg_final:.2f}%")
-    print(f" Difference: {osf_avg_final - full_avg_final:+.2f}% {'(OSF better)' if osf_avg_final > full_avg_final else '(Full FT better)'}")
+    print(
+        f" Difference: {osf_avg_final - full_avg_final:+.2f}% {'(OSF better)' if osf_avg_final > full_avg_final else '(Full FT better)'}"
+    )
 
     # Average forgetting (for tasks 1 and 2 only, since task 3 is the final task)
-    print(f"\n2. Average Forgetting (Task 1 & 2):")
-    print(f" Forgetting = Final Accuracy - Initial Accuracy (negative is worse)\n")
+    print("\n2. Average Forgetting (Task 1 & 2):")
+    print(" Forgetting = Final Accuracy - Initial Accuracy (negative is worse)\n")
 
     osf_forgetting_vals = []
     full_forgetting_vals = []
@@ -583,7 +585,7 @@ def print_results_comparison(osf_history, full_history):
     osf_avg_forgetting = sum(osf_forgetting_vals) / len(osf_forgetting_vals)
     full_avg_forgetting = sum(full_forgetting_vals) / len(full_forgetting_vals)
 
-    print(f" Average Forgetting:")
+    print(" Average Forgetting:")
     print(f" OSF: {osf_avg_forgetting:+.2f}%")
     print(f" Full FT: {full_avg_forgetting:+.2f}%")
     print(
@@ -673,7 +675,7 @@ def main():
         print(f" {task}: {acc:.2f}%")
 
     # Average forgetting
-    print(f"\n2. Average Forgetting (Task 1 & 2):")
+    print("\n2. Average Forgetting (Task 1 & 2):")
     osf_forgetting_vals = []
     for task_idx, task in enumerate(tasks[:-1]):
         osf_initial_acc = osf_history[task][0][1] * 100
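
For context, the regex answer extraction touched in the hunks above behaves identically with either quote style; a minimal standalone sketch with made-up model outputs (the sample strings are hypothetical, not from the repo):

import re

# Hypothetical generation for a ScienceQA-style item
generated_text = "The correct option is B. Photosynthesis"
matches = re.findall(r"\b([A-D])\b", generated_text)
pred = matches[-1] if matches else "X"  # "X" is the fallback when no letter is found
print(pred)  # B

# Hypothetical generation for a NumGLUE-style item
generated_text = "So the total is 42.5 apples"
numbers = re.findall(r"-?\d+\.?\d*", generated_text)
pred = numbers[-1] if numbers else "-999"  # "-999" is the fallback when no number is found
print(pred)  # 42.5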

examples/orthogonal_subspace_learning/utils.py

Lines changed: 7 additions & 11 deletions
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import random
 
 import torch
 from datasets import load_dataset
-from transformers import AutoTokenizer
 
 
 def load_scienceqa(num_train=1000, num_eval=200, seed=42):
@@ -54,6 +52,7 @@ def load_numglue(num_train=1000, num_eval=200, seed=42):
         train_dataset, eval_dataset
     """
     import json
+
     from datasets import Dataset
     from huggingface_hub import hf_hub_download
 
@@ -62,23 +61,20 @@ def load_numglue(num_train=1000, num_eval=200, seed=42):
 
     # Read and process the JSON file line by line
     data = []
-    with open(json_path, 'r') as f:
+    with open(json_path) as f:
         for line in f:
             if line.strip():  # Skip empty lines
                 item = json.loads(line)
                 # Extract the number from the answer JSON structure
-                answer = item.get('answer', '')
+                answer = item.get("answer", "")
                 if isinstance(answer, dict):
                     # NumGLUE answers are JSON with 'number' and 'date' fields
                     # Extract just the number field
-                    answer_str = answer.get('number', '')
+                    answer_str = answer.get("number", "")
                 else:
                     answer_str = str(answer)
 
-                data.append({
-                    'question': item.get('question', ''),
-                    'answer': answer_str
-                })
+                data.append({"question": item.get("question", ""), "answer": answer_str})
 
     # Create dataset from processed data
     dataset = Dataset.from_list(data)
@@ -131,7 +127,7 @@ def format_scienceqa_for_llama(examples, tokenizer, max_length=512):
         choices = examples["choices"][i]
 
         # Format choices
-        choices_text = "\n".join([f"{chr(65+j)}. {choice}" for j, choice in enumerate(choices)])
+        choices_text = "\n".join([f"{chr(65 + j)}. {choice}" for j, choice in enumerate(choices)])
 
         prompt = f"""Answer the following science question by selecting the correct option.
 Question: {question}
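
As an aside, the reformatted chr(65 + j) expression above maps a choice index to its option letter; a tiny illustration with placeholder options (not an actual ScienceQA record):

choices = ["gravity", "magnetism", "friction"]  # placeholder answer options
choices_text = "\n".join([f"{chr(65 + j)}. {choice}" for j, choice in enumerate(choices)])
print(choices_text)
# A. gravity
# B. magnetism
# C. friction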
@@ -279,7 +275,7 @@ def __init__(self, tokenizer, max_length=512):
 
     def __call__(self, features):
         # Pad sequences
-        max_len = min(max([len(f["input_ids"]) for f in features]), self.max_length)
+        max_len = min(max(len(f["input_ids"]) for f in features), self.max_length)
 
         input_ids = []
         attention_mask = []
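
For reference, the quote-normalized answer parsing in load_numglue can be exercised on its own; a minimal sketch over two made-up JSON-lines records (the records are hypothetical, not taken from the NumGLUE file):

import json

# Made-up records mimicking the NumGLUE JSON-lines layout
lines = [
    '{"question": "What is 2 + 3?", "answer": {"number": "5", "date": ""}}',
    '{"question": "Name the result of 10 - 4.", "answer": "6"}',
]

data = []
for line in lines:
    if line.strip():  # Skip empty lines
        item = json.loads(line)
        answer = item.get("answer", "")
        if isinstance(answer, dict):
            answer_str = answer.get("number", "")  # keep only the numeric field
        else:
            answer_str = str(answer)
        data.append({"question": item.get("question", ""), "answer": answer_str})

print(data[0]["answer"], data[1]["answer"])  # 5 6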
