2 changes: 1 addition & 1 deletion cvpods/data/base_dataset.py
@@ -116,7 +116,7 @@ def _apply_transforms(self, image, annotations=None, **kwargs):
         if isinstance(self.transforms, dict):
             dataset_dict = {}
             for key, tfms in self.transforms.items():
-                img = deepcopy(image)
+                img = np.copy(image)
                 annos = deepcopy(annotations)
                 for tfm in tfms:
                     img, annos = tfm(img, annos, **kwargs)
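Note (illustrative, not part of the diff): since the image passed in here is presumably a NumPy array, np.copy gives each transform pipeline its own pixel buffer just as deepcopy did, without recursively copying Python objects. A minimal sketch of the guarantee being relied on:

import numpy as np

image = np.zeros((4, 4, 3), dtype=np.uint8)
img = np.copy(image)         # independent buffer, cheaper than copy.deepcopy for ndarrays
img[0, 0] = 255              # in-place edits by one pipeline...
assert image[0, 0, 0] == 0   # ...do not leak back into the shared original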
2 changes: 1 addition & 1 deletion cvpods/data/build.py
@@ -123,7 +123,7 @@ def build_train_loader(cfg):
     rank = comm.get_rank()
 
     # use subdivision batchsize
-    images_per_minibatch = cfg.SOLVER.IMS_PER_DEVICE // cfg.SOLVER.BATCH_SUBDIVISIONS
+    images_per_minibatch = cfg.SOLVER.IMS_PER_DEVICE
 
     logger = logging.getLogger(__name__)
 
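Note (illustrative): read together with the accumulation hook further down, each dataloader step now appears to yield a full IMS_PER_DEVICE images per GPU, with BATCH_SUBDIVISIONS acting as a gradient-accumulation factor rather than splitting the per-device batch. A rough accounting sketch with made-up numbers, assuming IMS_PER_BATCH = IMS_PER_DEVICE x number of devices:

# Hypothetical values, purely to illustrate the accounting.
ims_per_device = 8        # cfg.SOLVER.IMS_PER_DEVICE
num_devices = 4
subdivisions = 2          # cfg.SOLVER.BATCH_SUBDIVISIONS

ims_per_batch = ims_per_device * num_devices      # 32 images consumed per iteration
effective_batch = ims_per_batch * subdivisions    # 64 images per optimizer update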
13 changes: 9 additions & 4 deletions cvpods/data/datasets/imagenet.py
@@ -66,10 +66,15 @@ def process(dd, img, annos):
             elif len(img.shape) == 3 and img.shape[-1] == 3:
                 dd["image"] = torch.as_tensor(
                     np.ascontiguousarray(img.transpose(2, 0, 1)))
-            elif len(img.shape) == 4 and img.shape[-1] == 3:
-                # NHWC -> NCHW
-                dd["image"] = torch.as_tensor(
-                    np.ascontiguousarray(img.transpose(0, 3, 1, 2)))
+            elif len(img.shape) == 4:
+                if img.shape[-1] == 3:
+                    # NHWC -> NCHW
+                    dd["image"] = torch.as_tensor(
+                        np.ascontiguousarray(img.transpose(0, 3, 1, 2)))
+                elif img.shape[1] == 3:
+                    # NCHW
+                    dd["image"] = torch.as_tensor(np.ascontiguousarray(img))
+
             return dd
 
         if isinstance(images, dict):
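Note (illustrative): the new branch keeps single images on the HWC -> CHW path and handles batched 4-D inputs either as NHWC (channels last, needs a transpose) or as NCHW (already channel-first, passed through). A standalone sketch of the same dispatch, not the cvpods helper itself:

import numpy as np
import torch

def to_chw_tensor(img: np.ndarray) -> torch.Tensor:
    if img.ndim == 3 and img.shape[-1] == 3:         # HWC -> CHW
        return torch.as_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
    if img.ndim == 4 and img.shape[-1] == 3:         # NHWC -> NCHW
        return torch.as_tensor(np.ascontiguousarray(img.transpose(0, 3, 1, 2)))
    if img.ndim == 4 and img.shape[1] == 3:          # already NCHW
        return torch.as_tensor(np.ascontiguousarray(img))
    raise ValueError(f"unexpected image shape {img.shape}")

batch = np.zeros((2, 224, 224, 3), dtype=np.uint8)   # a batch of NHWC images
assert to_chw_tensor(batch).shape == (2, 3, 224, 224)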
2 changes: 1 addition & 1 deletion cvpods/data/transforms/transform.py
@@ -319,7 +319,7 @@ def __call__(self, img, annotations=None, **kwargs):
         return img, annotations
 
     def __repr__(self):
-        return "".join([tfm for tfm in self.transforms])
+        return "".join([tfm.__repr__() for tfm in self.transforms])
 
 
 # TODO: Deprecated
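Note (illustrative): "".join expects strings, so joining the Transform objects directly raises a TypeError; mapping each one through __repr__ (or repr) produces the intended string. A minimal reproduction with a stand-in transform class:

class DummyTransform:              # stand-in; any object defining __repr__ behaves the same
    def __repr__(self):
        return "DummyTransform()"

transforms = [DummyTransform(), DummyTransform()]

# "".join(transforms)                          # TypeError: sequence item 0: expected str instance
print("".join([repr(t) for t in transforms]))  # DummyTransform()DummyTransform()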
3 changes: 2 additions & 1 deletion cvpods/engine/hooks.py
@@ -137,7 +137,8 @@ def __init__(self, accumulate_grad_steps=1, grad_clipper=None, mixed_precision=F
         self.mixed_precision = mixed_precision
 
     def before_step(self):
-        self.trainer.optimizer.zero_grad()
+        if self.trainer.iter % self.accumulate_grad_steps == 0:
+            self.trainer.optimizer.zero_grad()
 
     def after_step(self):
         losses = self.trainer.step_outputs["loss_for_backward"]
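Note (illustrative): zeroing gradients only on accumulation boundaries is what lets the intermediate iterations sum their gradients before a single optimizer step. A self-contained sketch of the same pattern in plain PyTorch, independent of the cvpods hook:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_steps = 4

for it in range(16):
    if it % accumulate_grad_steps == 0:
        optimizer.zero_grad()                      # start of an accumulation window
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accumulate_grad_steps).backward()      # scale so the summed gradient matches one large batch
    if (it + 1) % accumulate_grad_steps == 0:
        optimizer.step()                           # one update per accumulated batch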
15 changes: 7 additions & 8 deletions cvpods/engine/runner.py
@@ -84,12 +84,13 @@ def __init__(self, cfg, build_model):
         self.logger = logger
 
         self.data_loader = self.build_train_loader(cfg)
+        auto_scale_config(cfg, self.data_loader)
 
-        # Assume these objects must be constructed in this order.
         model = build_model(cfg)
         self.model = maybe_convert_module(model)
         self.logger.info(f"Model: \n{self.model}")
 
+        # Assume these objects must be constructed in this order.
         self.optimizer = self.build_optimizer(cfg, self.model)
 
         # For training, wrap with DDP. But don't need this for inference.
@@ -117,13 +118,12 @@ def __init__(self, cfg, build_model):
             )
 
         if not cfg.SOLVER.LR_SCHEDULER.get("EPOCH_WISE", False):
-            epoch_iters = -1
+            self.epoch_iters = -1
         else:
-            epoch_iters = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_ITERS")
-            self.logger.warning(f"Setup LR Scheduler in EPOCH mode: {epoch_iters}")
+            self.epoch_iters = cfg.SOLVER.LR_SCHEDULER.get("EPOCH_ITERS")
+            self.logger.warning(f"Setup LR Scheduler in EPOCH mode: {self.epoch_iters}")
 
-        auto_scale_config(cfg, self.data_loader)
-        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, epoch_iters=epoch_iters)
+        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, epoch_iters=self.epoch_iters)
         # Assume no other objects need to be checkpointed.
         # We can later make it checkpoint the stateful hooks
         self.checkpointer = DefaultCheckpointer(
@@ -402,8 +402,7 @@ def auto_scale_config(cfg, dataloader):
     Here we use batch size * subdivision to simulator large batch training
     """
     if max_epoch:
-        epoch_iter = math.ceil(
-            len(dataloader.dataset) / (cfg.SOLVER.IMS_PER_BATCH * subdivision))
+        epoch_iter = math.ceil(len(dataloader.dataset) / cfg.SOLVER.IMS_PER_BATCH)
 
     if max_iter is not None:
         logger.warning(
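Note (illustrative): assuming auto_scale_config converts epoch-based solver settings into iteration counts (which would explain why it now runs right after the dataloader is built and before the scheduler), each iteration is counted as consuming IMS_PER_BATCH images. The arithmetic with made-up numbers:

import math

# Hypothetical values, purely to illustrate the conversion.
dataset_len = 1_281_167                   # e.g. the ImageNet train split
ims_per_batch = 256                       # cfg.SOLVER.IMS_PER_BATCH, total across devices
max_epoch = 90

epoch_iter = math.ceil(dataset_len / ims_per_batch)    # 5005 iterations per epoch
max_iter = max_epoch * epoch_iter                       # 450450 iterations in total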