Skip to content

Commit dba614c

Browse files
committed
fix
1 parent 3b115d2 commit dba614c

File tree

2 files changed

+73
-18
lines changed

2 files changed

+73
-18
lines changed

examples/mlx_finetuning_optimization/demo.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,73 @@ def run_baseline(num_samples: int = 200, output_dir: str = "./demo_baseline"):
5858
return results
5959

6060

61+
def check_best_program_exists():
62+
"""Check if best_program.py exists and exit if not found"""
63+
# Check current directory first
64+
current_dir_best = os.path.join(os.getcwd(), "best_program.py")
65+
if os.path.exists(current_dir_best):
66+
print(f"✅ Found best_program.py in current directory: {current_dir_best}")
67+
return current_dir_best
68+
69+
# Check openevolve output directory
70+
script_dir = os.path.dirname(__file__)
71+
openevolve_output = os.path.join(script_dir, "openevolve_output")
72+
73+
if os.path.exists(openevolve_output):
74+
# Look for the best program
75+
best_dir = os.path.join(openevolve_output, "best")
76+
if os.path.exists(best_dir):
77+
best_program = os.path.join(best_dir, "best_program.py")
78+
if os.path.exists(best_program):
79+
print(f"✅ Found best_program.py in openevolve output: {best_program}")
80+
return best_program
81+
82+
# Look in checkpoints for latest
83+
checkpoints_dir = os.path.join(openevolve_output, "checkpoints")
84+
if os.path.exists(checkpoints_dir):
85+
checkpoints = [d for d in os.listdir(checkpoints_dir) if d.startswith("checkpoint_")]
86+
if checkpoints:
87+
latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("_")[1]))
88+
checkpoint_program = os.path.join(checkpoints_dir, latest_checkpoint, "best_program.py")
89+
if os.path.exists(checkpoint_program):
90+
print(f"✅ Found best_program.py in latest checkpoint: {checkpoint_program}")
91+
return checkpoint_program
92+
93+
# If we get here, no best_program.py was found
94+
print("❌ Error: best_program.py not found!")
95+
print("")
96+
print("The demo requires a best_program.py file with evolved optimizations.")
97+
print("")
98+
print("To get best_program.py, you can:")
99+
print(" 1. Run evolution: python demo.py --evolve --iterations 50")
100+
print(" 2. Copy from openevolve_output/best/ if it exists")
101+
print(" 3. Copy from a checkpoint: openevolve_output/checkpoints/checkpoint_*/best_program.py")
102+
print("")
103+
print("Searched locations:")
104+
print(f" • Current directory: {current_dir_best}")
105+
print(f" • OpenEvolve output: {os.path.join(script_dir, 'openevolve_output', 'best', 'best_program.py')}")
106+
print(f" • Latest checkpoint: {os.path.join(script_dir, 'openevolve_output', 'checkpoints', '*', 'best_program.py')}")
107+
print("")
108+
sys.exit(1)
109+
110+
61111
def run_optimized(num_samples: int = 200, output_dir: str = "./demo_optimized"):
62112
"""Run optimized MLX fine-tuning"""
63113
print("⚡ Running Optimized MLX Fine-tuning")
64114
print("=" * 50)
65115

116+
# Check that best_program.py exists before proceeding
117+
best_program_path = check_best_program_exists()
118+
66119
try:
67-
# Create trainer with automatic optimization loading
68-
trainer = create_optimized_trainer("mlx-community/Qwen3-0.6B-bf16")
120+
# Create trainer with specific optimization path
121+
trainer = create_optimized_trainer("mlx-community/Qwen3-0.6B-bf16", best_program_path)
69122
trainer.config.batch_size = 2
70123
trainer.config.num_epochs = 1
124+
print(f"✅ Created optimized trainer using {best_program_path}")
71125
except Exception as e:
72-
print(f"⚠️ Failed to create optimized trainer: {e}")
73-
print("Falling back to baseline with default optimizations...")
74-
trainer = BaselineTrainer("mlx-community/Qwen3-0.6B-bf16")
75-
trainer.config.batch_size = 2
76-
trainer.config.num_epochs = 1
77-
# Try to apply any available optimizations
78-
try:
79-
apply_optimizations(trainer)
80-
print("✅ Applied optimizations to baseline trainer")
81-
except Exception as opt_error:
82-
print(f"⚠️ Could not apply optimizations: {opt_error}")
83-
print("Using baseline trainer without optimizations")
126+
print(f"❌ Failed to create optimized trainer: {e}")
127+
sys.exit(1)
84128

85129
print(f"Creating {num_samples} training samples...")
86130
dataset = trainer.create_sample_dataset(num_samples)
@@ -105,10 +149,16 @@ def compare_performance(num_samples: int = 200):
105149
print("🏁 Comparing Baseline vs Optimized Performance")
106150
print("=" * 50)
107151

152+
# Check that best_program.py exists before proceeding
153+
best_program_path = check_best_program_exists()
154+
108155
print("Running comprehensive benchmark...")
156+
# Pass the specific best program path to ensure we use the evolved optimizations
157+
from mlx_optimization_patch import benchmark_optimization_improvement
109158
results = benchmark_optimization_improvement(
110159
model_name="mlx-community/Qwen3-0.6B-bf16",
111-
num_samples=num_samples
160+
num_samples=num_samples,
161+
optimization_path=best_program_path
112162
)
113163

114164
baseline = results["baseline"]
@@ -213,6 +263,9 @@ def demo_context_manager():
213263
print("🎭 Demonstrating Context Manager Usage")
214264
print("=" * 50)
215265

266+
# Check that best_program.py exists before proceeding
267+
best_program_path = check_best_program_exists()
268+
216269
# Example of how users would integrate into existing code
217270
trainer = BaselineTrainer("mlx-community/Qwen3-0.6B-bf16")
218271
trainer.config.batch_size = 1
@@ -222,7 +275,7 @@ def demo_context_manager():
222275

223276
print("Training with automatic optimizations...")
224277

225-
with mlx_optimizations():
278+
with mlx_optimizations(best_program_path):
226279
# All training inside this context will use optimized patterns
227280
results = trainer.train(dataset, "./demo_context_output")
228281

examples/mlx_finetuning_optimization/mlx_optimization_patch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,15 @@ def create_optimized_trainer(model_name: str = "mlx-community/Qwen3-0.6B-bf16",
244244

245245

246246
def benchmark_optimization_improvement(model_name: str = "mlx-community/Qwen3-0.6B-bf16",
247-
num_samples: int = 100) -> Dict[str, Any]:
247+
num_samples: int = 100,
248+
optimization_path: Optional[str] = None) -> Dict[str, Any]:
248249
"""
249250
Benchmark the improvement from evolved optimizations
250251
251252
Args:
252253
model_name: Model to benchmark
253254
num_samples: Number of training samples
255+
optimization_path: Path to optimization patterns (None for auto-detection)
254256
255257
Returns:
256258
Benchmark results comparing baseline vs optimized
@@ -264,7 +266,7 @@ def benchmark_optimization_improvement(model_name: str = "mlx-community/Qwen3-0.
264266
baseline_results = baseline_trainer.train(baseline_dataset, "./benchmark_baseline")
265267

266268
print("Benchmarking optimized trainer...")
267-
optimized_trainer = create_optimized_trainer(model_name)
269+
optimized_trainer = create_optimized_trainer(model_name, optimization_path)
268270
optimized_trainer.config.batch_size = 2
269271
optimized_dataset = optimized_trainer.create_sample_dataset(num_samples)
270272
optimized_results = optimized_trainer.train(optimized_dataset, "./benchmark_optimized")

0 commit comments

Comments
 (0)