Skip to content

Commit e8c5c31

Browse files
committed
Merge branch 'develop'
2 parents 372017f + a296c77 commit e8c5c31

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

framework3/plugins/optimizer/wandb_optimizer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,12 @@ def fit(self, x: XYData, y: XYData | None = None) -> None:
181181
)
182182

183183
if self.sweep_id is not None:
184-
WandbAgent()(
185-
self.sweep_id, self.project, lambda config: self.exec(config, x, y)
186-
)
184+
sweep = WandbSweepManager().get_sweep(self.project, self.sweep_id)
185+
sweep_state = sweep.state.lower()
186+
if sweep_state not in ("finished", "cancelled", "crashed"):
187+
WandbAgent()(
188+
self.sweep_id, self.project, lambda config: self.exec(config, x, y)
189+
)
187190
else:
188191
raise ValueError("Either pipeline or sweep_id must be provided")
189192

framework3/plugins/storage/s3_storage.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,32 @@ def upload_file(
138138

139139
return file_name
140140

141-
def list_stored_files(self, context) -> List[Any]:
141+
def list_stored_files(self, context: str) -> List[str]:
142142
"""
143-
List all files in the S3 bucket.
143+
List all files in a specific folder (context) in the S3 bucket.
144144
145145
Args:
146-
context (str): Not used in this implementation.
146+
context (str): The folder path within the bucket to list files from.
147147
148148
Returns:
149-
List[str]: A list of object keys in the bucket.
149+
List[str]: A list of object keys in the specified folder.
150150
"""
151-
return list(
152-
map(
153-
lambda x: x["Key"],
154-
self._client.list_objects_v2(Bucket=self.bucket)["Contents"],
155-
)
156-
)
151+
# Ensure the context ends with a trailing slash if it's not empty
152+
prefix = f"{context}/" if context and not context.endswith("/") else context
153+
154+
paginator = self._client.get_paginator("list_objects_v2")
155+
pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix)
156+
157+
file_list = []
158+
for page in pages:
159+
if "Contents" in page:
160+
for obj in page["Contents"]:
161+
# Remove the prefix from the key to get the relative path
162+
relative_path = obj["Key"][len(prefix) :]
163+
if relative_path: # Ignore the folder itself
164+
file_list.append(relative_path)
165+
166+
return file_list
157167

158168
def get_file_by_hashcode(self, hashcode: str, context: str) -> bytes:
159169
"""

0 commit comments

Comments
 (0)