Skip to content

Commit 52ba582

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9461c1f commit 52ba582

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

scripts/bug_gen_modal.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ async def acquire(self):
523523
image=generator_image,
524524
secrets=[
525525
modal.Secret.from_name("GITHUB_TOKEN"),
526-
modal.Secret.from_name("PORTKEY_API_KEY")
526+
modal.Secret.from_name("PORTKEY_API_KEY"),
527527
],
528528
timeout=MODAL_TIMEOUT,
529529
volumes={LOGS_MOUNT_PATH: logs_volume}, # Mount volume for direct writes
@@ -1885,11 +1885,11 @@ def issue_gen_remote(
18851885
# Set up paths
18861886
volume_root = Path(LOGS_MOUNT_PATH) / language
18871887
task_insts_dir = volume_root / "task_insts"
1888-
1888+
18891889
# Resolve task instances file (it may have a hash suffix like repo__name.abcdef.json)
18901890
task_insts_file = None
18911891
repo_sanitized = repo.replace("/", "__")
1892-
1892+
18931893
if task_insts_dir.exists():
18941894
for filename in os.listdir(task_insts_dir):
18951895
# Check for exact match or match with suffix
@@ -1961,6 +1961,7 @@ def issue_gen_remote(
19611961
instances_processed = 0
19621962
if issue_gen_file.exists():
19631963
import json
1964+
19641965
with open(issue_gen_file) as f:
19651966
data = json.load(f)
19661967
instances_processed = len(data)
@@ -1973,6 +1974,7 @@ def issue_gen_remote(
19731974

19741975
except Exception as e:
19751976
import traceback
1977+
19761978
return {
19771979
"success": False,
19781980
"repo": repo,
@@ -1996,9 +1998,9 @@ async def run_issue_gen_phase_async(
19961998
issue_gen_workers: Number of workers per repo
19971999
issue_gen_redo: Whether to regenerate existing issues
19982000
"""
1999-
print(f"\n{'='*80}")
2001+
print(f"\n{'=' * 80}")
20002002
print(f"ISSUE GENERATION PHASE")
2001-
print(f"{'='*80}")
2003+
print(f"{'=' * 80}")
20022004
print(f"Processing {len(repos)} repositories...")
20032005
print(f"Config: {issue_gen_config}")
20042006
print(f"Workers per repo: {issue_gen_workers}")
@@ -2015,14 +2017,15 @@ async def run_issue_gen_phase_async(
20152017
},
20162018
order_outputs=False,
20172019
):
2018-
20192020
results.append(result)
20202021

20212022
# Print progress
20222023
completed = len(results)
20232024
if result["success"]:
20242025
instances = result.get("instances_processed", 0)
2025-
print(f" [{completed}/{len(repos)}] {result['repo']}: ✓ ({instances} instances)")
2026+
print(
2027+
f" [{completed}/{len(repos)}] {result['repo']}: ✓ ({instances} instances)"
2028+
)
20262029
else:
20272030
error = result.get("error", "Unknown error")
20282031
print(f" [{completed}/{len(repos)}] {result['repo']}: ✗ {error}")
@@ -2031,9 +2034,13 @@ async def run_issue_gen_phase_async(
20312034

20322035
# Summary
20332036
success = sum(1 for r in results if r["success"])
2034-
total_instances = sum(r.get("instances_processed", 0) for r in results if r["success"])
2037+
total_instances = sum(
2038+
r.get("instances_processed", 0) for r in results if r["success"]
2039+
)
20352040

2036-
print(f"\nIssue generation complete: {success}/{len(repos)} repos processed successfully.")
2041+
print(
2042+
f"\nIssue generation complete: {success}/{len(repos)} repos processed successfully."
2043+
)
20372044
print(f"Total instances with issues: {total_instances}\n")
20382045

20392046

swesmith/issue_gen/generate.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
7070
logger = logging.getLogger(__name__)
7171

72+
7273
class PortkeyModelConfig(BaseModel):
7374
model_name: str
7475
model_kwargs: dict[str, Any] = {}
@@ -83,7 +84,7 @@ def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
8384
raise ImportError(
8485
"The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
8586
)
86-
87+
8788
self.config = config_class(**kwargs)
8889
self.cost = 0.0
8990
self.n_calls = 0
@@ -124,7 +125,10 @@ def _query(self, messages: list[dict[str, str]], **kwargs):
124125

125126
def query(self, messages: list[dict[str, str]], **kwargs) -> Any:
126127
# Simple adapter to match what generate.py expects (return an object with choices and usage for cost)
127-
response = self._query([{"role": msg["role"], "content": msg["content"]} for msg in messages], **kwargs)
128+
response = self._query(
129+
[{"role": msg["role"], "content": msg["content"]} for msg in messages],
130+
**kwargs,
131+
)
128132
return response
129133

130134

@@ -167,12 +171,17 @@ def __init__(
167171

168172
# Initialize Portkey model if needed
169173
self.portkey_model = None
170-
if self.model.startswith("portkey/") or self.config.get("provider") == "portkey":
174+
if (
175+
self.model.startswith("portkey/")
176+
or self.config.get("provider") == "portkey"
177+
):
171178
self.portkey_model = PortkeyModel(
172179
model_name=self.model.replace("portkey/", ""),
173180
provider=self.config.get("provider", "openai"),
174-
litellm_model_name_override=self.config.get("litellm_model_name_override", ""),
175-
**settings.get("portkey_kwargs", {})
181+
litellm_model_name_override=self.config.get(
182+
"litellm_model_name_override", ""
183+
),
184+
**settings.get("portkey_kwargs", {}),
176185
)
177186

178187
data_smith = [x for x in load_dataset(HF_DATASET, split="train")]
@@ -365,18 +374,23 @@ def jinja_shuffle(seq):
365374

366375
# Generate n_instructions completions containing problem statements
367376
if self.portkey_model:
368-
response = self.portkey_model.query(messages, n=self.n_instructions, stream=False)
377+
response = self.portkey_model.query(
378+
messages, n=self.n_instructions, stream=False
379+
)
369380
else:
370381
response = completion(
371-
model=self.model, messages=messages, n=self.n_instructions, temperature=0
382+
model=self.model,
383+
messages=messages,
384+
n=self.n_instructions,
385+
temperature=0,
372386
)
373387

374388
model_for_cost = self.model
375389
if self.portkey_model and self.portkey_model.config.litellm_model_name_override:
376390
model_for_cost = self.portkey_model.config.litellm_model_name_override
377-
391+
378392
cost = completion_cost(response, model=model_for_cost)
379-
393+
380394
metadata["cost"] = (0 if "cost" not in metadata else metadata["cost"]) + cost
381395

382396
# Extract problem statements from response

0 commit comments

Comments (0)