
Commit 569143c
Update evaluator.py
1 parent ca90538

File tree: 1 file changed, +21 −67 lines changed


examples/mlx_metal_kernel_opt/evaluator.py

Lines changed: 21 additions & 67 deletions
@@ -53,7 +53,7 @@ def __init__(self):
 
         print("🔧 Initialized Fixed Custom GQA Evaluator")
         print(f"📱 Model: {self.model_path}")
-        print(f"🧪 Using comprehensive test suite (20+ scenarios)")
+        print(f"🧪 Using 5 representative tests for fast evolution")
         print(f"📊 Dynamic baseline measurement enabled")
 
     def evaluate(self, program_text: str) -> Dict[str, Any]:
@@ -69,7 +69,7 @@ def evaluate(self, program_text: str) -> Dict[str, Any]:
         print("🔬 FIXED CUSTOM GQA ATTENTION EVALUATION")
         print("=" * 100)
         print("✅ Using dynamic baseline measurement")
-        print("✅ Using comprehensive test coverage (20+ scenarios)")
+        print("✅ Using 5 representative tests for fast evolution")
         print("✅ Using direct model testing (no subprocess)")
         print("✅ Using proper statistical methodology")
         print("=" * 100)
@@ -271,80 +271,34 @@ def _measure_baseline_performance(self) -> Optional[List[BenchmarkResult]]:
         return None
 
     def _get_evolution_benchmark_configs(self) -> List[BenchmarkConfig]:
-        """Get representative benchmark configs for evolution (subset of full suite for speed)"""
+        """Get 5 most representative benchmark configs for faster evolution"""
 
         # Get all comprehensive configs
         all_configs = self.benchmark_suite.create_benchmark_configs()
 
-        # Select representative subset across all categories for faster evolution
-        # while maintaining comprehensive coverage
+        # Select only 5 most representative tests across all categories
+        # for significantly faster evolution while maintaining coverage
         representative_configs = []
 
-        # Context length variations (4 configs)
-        context_configs = [c for c in all_configs if "context" in c.name]
-        representative_configs.extend(context_configs)  # All 4 context tests are important
-
-        # Generation length patterns (select key ones)
-        generation_configs = [c for c in all_configs if "generation" in c.name]
-        representative_configs.extend(
-            [
-                c
-                for c in generation_configs
-                if c.name
-                in [
-                    "micro_generation",
-                    "short_generation",
-                    "long_generation",
-                    "very_long_generation",
-                ]
-            ]
-        )
-
-        # Use case patterns (select most important)
-        use_case_configs = [
-            c
-            for c in all_configs
-            if any(
-                x in c.name
-                for x in ["code", "reasoning", "creative", "technical", "conversational"]
-            )
-        ]
-        representative_configs.extend(
-            [
-                c
-                for c in use_case_configs
-                if c.name
-                in ["code_generation", "step_by_step_reasoning", "conversational_assistant"]
-            ]
-        )
-
-        # Memory pressure (select key ones)
-        memory_configs = [
-            c for c in all_configs if any(x in c.name for x in ["progressive", "repetitive"])
+        # Map of specific test names to select
+        selected_test_names = [
+            "short_context_quick",  # Short context + quick response (chat scenario)
+            "long_context_detailed",  # Long context analysis (memory pressure)
+            "long_generation",  # Long generation (decode performance critical)
+            "code_generation",  # Code generation (structured output patterns)
+            "maximum_context_stress_test"  # Ultimate stress test (maximum challenge)
         ]
-        representative_configs.extend(
-            [
-                c
-                for c in memory_configs
-                if c.name in ["progressive_context_building", "repetitive_pattern_generation"]
-            ]
-        )
 
-        # Extended tests (select 1-2 key ones)
-        extended_configs = [
-            c
-            for c in all_configs
-            if any(x in c.name for x in ["extreme", "sustained", "comprehensive", "maximum"])
-        ]
-        representative_configs.extend(
-            [
-                c
-                for c in extended_configs
-                if c.name in ["extreme_long_generation", "maximum_context_stress_test"]
-            ]
-        )
+        # Find and add the selected tests
+        config_dict = {c.name: c for c in all_configs}
+
+        for test_name in selected_test_names:
+            if test_name in config_dict:
+                representative_configs.append(config_dict[test_name])
+            else:
+                print(f" ⚠️ Warning: Test '{test_name}' not found in benchmark suite")
 
-        print(f" 📋 Selected {len(representative_configs)} representative benchmarks:")
+        print(f" 📋 Selected {len(representative_configs)} representative benchmarks for fast evolution:")
         for config in representative_configs:
             print(f" • {config.name}: {config.description}")
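
The replacement logic above boils down to a name-keyed lookup over the full benchmark suite. Below is a minimal, self-contained sketch of that pattern, assuming a simplified BenchmarkConfig with only name and description fields and a hypothetical standalone select_configs helper; the repo's actual class carries more fields and the logic lives inside the evaluator method.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class BenchmarkConfig:
    # Simplified stand-in for the repo's BenchmarkConfig (assumption:
    # the real class also carries prompt and generation parameters).
    name: str
    description: str


def select_configs(
    all_configs: List[BenchmarkConfig], selected_test_names: List[str]
) -> List[BenchmarkConfig]:
    """Pick named configs from the full suite, warning on any miss."""
    # Index the full suite by name so each selection is an O(1) lookup.
    config_dict: Dict[str, BenchmarkConfig] = {c.name: c for c in all_configs}

    representative_configs: List[BenchmarkConfig] = []
    for test_name in selected_test_names:
        if test_name in config_dict:
            representative_configs.append(config_dict[test_name])
        else:
            # A missing name is reported but does not abort evaluation,
            # matching the committed behavior.
            print(f" ⚠️ Warning: Test '{test_name}' not found in benchmark suite")
    return representative_configs


if __name__ == "__main__":
    suite = [
        BenchmarkConfig("short_context_quick", "Short context + quick response"),
        BenchmarkConfig("long_generation", "Decode-heavy long generation"),
    ]
    picked = select_configs(suite, ["short_context_quick", "code_generation"])
    for cfg in picked:
        print(f" • {cfg.name}: {cfg.description}")

Pinning the evolution-time benchmark set to an explicit name list keeps the selection auditable in one place, and the dict lookup makes it independent of the order in which create_benchmark_configs() returns its configs.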
