Skip to content

Commit 91f0efe

Browse files
authored
Merge pull request #128 from psychicmario/core50
added choice of object/category level detection
2 parents 22f60d3 + 7d5b367 commit 91f0efe

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

continuum/datasets/core50.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ def __init__(
3131
train: bool = True,
3232
train_image_ids: Union[str, Iterable[str], None] = None,
3333
scenario: str = "classes",
34-
download: bool = True
34+
download: bool = True,
35+
classification: str = "object"
3536
):
3637
assert scenario in ["classes", "domains", "objects"]
38+
assert classification in ["object", "category"]
3739
self.train_image_ids = train_image_ids
3840
super().__init__(data_path=data_path, train=train, download=download)
3941

4042
self.scenario = scenario
43+
self.classification = classification
4144

4245
if self.train_image_ids is None:
4346
self.train_image_ids = os.path.join(self.data_path, "core50_train.csv")
@@ -126,7 +129,12 @@ def get_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
126129
continue
127130

128131
x.append(os.path.join(object_folder, path))
129-
class_label = object_id // 5
132+
if self.classification == "object":
133+
# Ranges from 0-49
134+
class_label = object_id
135+
else:
136+
# Ranges from 0-9
137+
class_label = object_id // 5
130138
y.append(class_label)
131139

132140
if self.train: # We add a new domain id for the train set.

0 commit comments

Comments
 (0)