diff --git a/dlframeworks/chainer/communicators/kfac_communicators/__init__.py b/dlframeworks/chainer/communicators/kfac_communicators/__init__.py
index ac2ca08..d6c52c3 100644
--- a/dlframeworks/chainer/communicators/kfac_communicators/__init__.py
+++ b/dlframeworks/chainer/communicators/kfac_communicators/__init__.py
@@ -1,5 +1,5 @@
 def create_communicator(
-        communicator_name='flat', mpi_comm=None, dynamic=False, debug=False):
+        communicator_name='flat', mpi_comm=None, debug=False):
     if mpi_comm is None:
         import mpi4py.MPI
@@ -8,8 +8,7 @@ def create_communicator(
     if communicator_name == 'flat':
         from dlframeworks.chainer.communicators.kfac_communicators\
             .flat_communicator import FlatCommunicator
-        return FlatCommunicator(
-            mpi_comm=mpi_comm, dynamic=dynamic, debug=debug)
+        return FlatCommunicator(mpi_comm, debug)
     else:
         raise ValueError(
             'Unrecognized communicator: "{}"'.format(communicator_name))
diff --git a/dlframeworks/chainer/communicators/kfac_communicators/_memory_utility.py b/dlframeworks/chainer/communicators/kfac_communicators/_memory_utility.py
deleted file mode 100644
index 6b34ef2..0000000
--- a/dlframeworks/chainer/communicators/kfac_communicators/_memory_utility.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from dlframeworks.chainer.utils import get_link
-
-
-def reduce_scatterv_pack(dictionary, divided_linknames, gpu_buf, sizeof_dtype):
-    sendcounts = []
-    displs = []
-    buf_offset = 0
-    sendcount_offset = 0
-    for linknames in divided_linknames:
-        sendcount = 0
-        for linkname in sorted(linknames):
-            arrays = dictionary[linkname]
-            for array in arrays:
-                sendcount += array.size
-                nbytes = array.size * sizeof_dtype
-                gpu_buf.from_device(array, nbytes, buf_offset)
-                buf_offset += nbytes
-        sendcounts.append(sendcount)
-        displs.append(sendcount_offset)
-        sendcount_offset += sendcount
-    return sendcounts, displs
-
-
-def reduce_scatterv_unpack(dictionary, linknames, gpu_buf, sizeof_dtype):
-    buf_offset = 0
-    for linkname in sorted(linknames):
-        arrays = dictionary[linkname]
-        for array in arrays:
-            nbytes = array.size * sizeof_dtype
-            gpu_buf.to_device(array, nbytes, buf_offset)
-            buf_offset += nbytes
-
-
-def allgatherv_pack(model, divided_linknames, gpu_buf, sizeof_dtype, rank):
-    sendcounts = []
-    displs = []
-    sendcount_offset = 0
-    buf_offset = 0
-    for i, linknames in enumerate(divided_linknames):
-        sendcount = 0
-        for linkname in sorted(linknames):
-            link = get_link(model, linkname)
-            for paramname, param in sorted(link.namedparams()):
-                if param.kfgrad is None:
-                    continue
-                sendcount += param.kfgrad.size
-                if i == rank:
-                    nbytes = param.kfgrad.size * sizeof_dtype
-                    gpu_buf.from_device(param.kfgrad, nbytes, buf_offset)
-                    buf_offset += nbytes
-        sendcounts.append(sendcount)
-        displs.append(sendcount_offset)
-        sendcount_offset += sendcount
-    return sendcounts, displs
-
-
-def allgatherv_unpack(model, linknames, gpu_buf, sizeof_dtype):
-    buf_offset = 0
-    for linkname in linknames:
-        link = get_link(model, linkname)
-        for paramname, param in sorted(link.namedparams()):
-            if param.kfgrad is None:
-                continue
-            nbytes = param.kfgrad.size * sizeof_dtype
-            gpu_buf.to_device(param.kfgrad, nbytes, buf_offset)
-            buf_offset += nbytes
diff --git a/dlframeworks/chainer/communicators/kfac_communicators/_utility.py b/dlframeworks/chainer/communicators/kfac_communicators/_utility.py
new file mode 100644
index 0000000..713c223
--- /dev/null
+++ b/dlframeworks/chainer/communicators/kfac_communicators/_utility.py
@@ -0,0 +1,99 @@
+import itertools
+
+from dlframeworks.chainer.utils import create_mpi_print
+
+
+def extract(fisher_blocks, indices, extractors):
+    arrays = []
+    for local_indices in indices:
+        if len(local_indices) == 0:
+            arrays.append([])
+        else:
+            local_arrays = []
+            for index in local_indices:
+                for extractor in extractors:
+                    for array in extractor(fisher_blocks[index]):
+                        local_arrays.append(array)
+            arrays.append(local_arrays)
+    return arrays
+
+
+def extract_cov_emas(fisher_block):
+    ret = []
+    if fisher_block.cov_emas is not None:
+        for cov_ema in fisher_block.cov_emas:
+            ret.append(cov_ema)
+    return ret
+
+
+def extract_grads(fisher_block):
+    ret = []
+    for _, param in sorted(fisher_block.link.namedparams()):
+        if param.grad is not None:
+            ret.append(param.grad)
+    return ret
+
+
+def extract_kfgrads(fisher_block):
+    ret = []
+    for _, param in sorted(fisher_block.link.namedparams()):
+        if hasattr(param, 'kfgrad') and param.kfgrad is not None:
+            ret.append(param.kfgrad)
+    return ret
+
+
+def get_nelems(arrays):
+    nelems = 0
+    for array in list(itertools.chain(*arrays)):  # flatten arrays
+        nelems += array.size
+    return nelems
+
+
+def get_sendcounts_and_displs(arrays):
+    sendcounts = []
+    displs = []
+    sendcount_offset = 0
+    for local_arrays in arrays:
+        sendcount = 0
+        for array in local_arrays:
+            sendcount += array.size
+        sendcounts.append(sendcount)
+        displs.append(sendcount_offset)
+        sendcount_offset += sendcount
+    return sendcounts, displs
+
+
+def pack(arrays, gpu_buf, sizeof_dtype):
+    buf_offset = 0
+    for array in list(itertools.chain(*arrays)):  # flatten arrays
+        nbytes = array.size * sizeof_dtype
+        gpu_buf.from_device(array, nbytes, buf_offset)
+        buf_offset += nbytes
+
+
+def unpack(arrays, gpu_buf, sizeof_dtype):
+    buf_offset = 0
+    for array in list(itertools.chain(*arrays)):  # flatten arrays
+        nbytes = array.size * sizeof_dtype
+        gpu_buf.to_device(array, nbytes, buf_offset)
+        buf_offset += nbytes
+
+
+def allocate_kfgrads(fisher_blocks):
+    for fisher_block in fisher_blocks:
+        for _, param in sorted(fisher_block.link.namedparams()):
+            if param.grad is None:
+                continue
+            if not hasattr(param, 'kfgrad'):
+                kfgrad = param.grad.copy()
+                kfgrad.fill(0.)
+                setattr(param, 'kfgrad', kfgrad)
+
+
+def print_debug_message(mpi_comm, arrays, prefix):
+    mpi_print = create_mpi_print(mpi_comm)
+    idx = 0
+    for array in list(itertools.chain(*arrays)):  # flatten arrays
+        mpi_print('{} {} MEAN {}'.format(
+            prefix, idx, array.mean()))
+        idx += 1
diff --git a/dlframeworks/chainer/communicators/kfac_communicators/flat_communicator.py b/dlframeworks/chainer/communicators/kfac_communicators/flat_communicator.py
index 8fcf58c..9b9448c 100644
--- a/dlframeworks/chainer/communicators/kfac_communicators/flat_communicator.py
+++ b/dlframeworks/chainer/communicators/kfac_communicators/flat_communicator.py
@@ -5,14 +5,13 @@
 from dlframeworks.chainer.communicators.kfac_communicators \
     import kfac_communicator_base
 from dlframeworks.chainer.communicators.kfac_communicators \
-    import _memory_utility
+    import _utility
 
 
 class FlatCommunicator(kfac_communicator_base.KFACCommunicatorBase):
 
-    def __init__(self, mpi_comm, dynamic=False, debug=False):
-        super(FlatCommunicator, self).__init__(
-            mpi_comm, False, dynamic, debug)
+    def __init__(self, mpi_comm, debug=False):
+        super(FlatCommunicator, self).__init__(mpi_comm, debug)
 
         # GPU buffers
         self.gpu_buffer_a = DeviceMemory()
@@ -22,38 +21,42 @@ def __init__(self, mpi_comm, dynamic=False, debug=False):
         self.mpi_dtype = mpi4py.MPI.FLOAT
         self.sizeof_dtype = 4
 
-    def reduce_scatterv(self, model, covs, root=0):
+    def reduce_scatterv_grad(self, fisher_blocks, root=0):
         """Reduce and Scatterv grads and covs
 
-        1. Extract (by reference)
-            model, covs -> dictionary
+        1. Extract
+            grads, cov_emas -> arrays
         2. Pack
-            dictionary -> GPU buffer A
+            arrays -> GPU buffer A
         3. Reduce
             GPU buffer A -> GPU buffer B
         4. Scatterv
             GPU buffer B -> GPU buffer A
         5. Unpack
-            GPU buffer A -> dictionary
+            GPU buffer A -> arrays
 
         """
-        self.setup(model)
+        self.setup(fisher_blocks)
         cuda_stream = chainer.cuda.Stream.null
 
-        dictionary = self.reduce_scatterv_extract(model, covs)
-        nelems = self.reduce_scatterv_get_nelems(dictionary)
+        # We extract cov_emas and grads from fisher_blocks
+        extractors = [_utility.extract_cov_emas, _utility.extract_grads]
+        arrays = _utility.extract(fisher_blocks, self.indices, extractors)
+
+        # Get total number of elements
+        nelems = _utility.get_nelems(arrays)
         nbytes = nelems * self.sizeof_dtype
         self.gpu_buffer_a.assign(nbytes)
         self.gpu_buffer_b.assign(nbytes)
 
-        # Pack the elements in a single buffer, calculate sendcounts, and
-        # calculate displs
+        # Calculate sendcounts, and calculate displs
         #   - sendcounts: the number of elements to send to each process
         #   - displs: the displacements where each segment begins
-        sendcounts, displs = _memory_utility.reduce_scatterv_pack(
-            dictionary, self.divided_linknames, self.gpu_buffer_a,
-            self.sizeof_dtype)
+        sendcounts, displs = _utility.get_sendcounts_and_displs(arrays)
+
+        # Pack the elements in a single buffer
+        _utility.pack(arrays, self.gpu_buffer_a, self.sizeof_dtype)
 
         # Buffers for Reduce
         sendbuf = [self.gpu_buffer_a.buffer(nbytes), self.mpi_dtype]
@@ -61,59 +64,70 @@ def reduce_scatterv(self, model, covs, root=0):
             self.rank == root else None
 
         if self.debug:
-            self.reduce_scatterv_debug(dictionary, 'BEFORE')
+            _utility.print_debug_message(self.mpi_comm, arrays,
+                                         'BEFORE REDUCE_SCATTERV')
 
         # We must sync before communication
        cuda_stream.synchronize()
 
         self.mpi_comm.Reduce(sendbuf, recvbuf, root=root)
 
-        if not self.is_worker:
+        if not self.is_inv_worker:
             return
 
         # Buffers for Scatterv
-        nbytes_local = sendcounts[self.invcomm.rank] * self.sizeof_dtype
+        nbytes_local = sendcounts[self.inv_comm.rank] * self.sizeof_dtype
         sendbuf = [self.gpu_buffer_b.buffer(nbytes), sendcounts, displs,
                    self.mpi_dtype] if self.rank == root else None
         recvbuf = self.gpu_buffer_a.buffer(nbytes_local)
 
         # We must sync before communication
         cuda_stream.synchronize()
 
-        self.invcomm.mpi_comm.Scatterv(sendbuf, recvbuf, root=root)
+        self.inv_comm.mpi_comm.Scatterv(sendbuf, recvbuf, root=root)
 
         # Unpack the all elements
-        _memory_utility.reduce_scatterv_unpack(
-            dictionary, self.divided_linknames[self.invcomm.rank],
-            self.gpu_buffer_a, self.sizeof_dtype)
+        _utility.unpack(arrays[self.inv_comm.rank], self.gpu_buffer_a,
+                        self.sizeof_dtype)
 
         if self.debug:
-            self.reduce_scatterv_debug(dictionary, 'AFTER')
+            _utility.print_debug_message(self.mpi_comm, arrays,
+                                         'AFTER REDUCE_SCATTERV')
 
-    def allgatherv(self, model):
+    def allgatherv_kfgrad(self, fisher_blocks):
         """Allgatherv kfgrads
 
+        1. Extract
+            kfgrads -> arrays
         1. Pack
-            kfgrads -> GPU buffer A
+            arrays -> GPU buffer A
         2. Allgatherv
             GPU buffer A -> GPU buffer B
         3. Unpack
-            GPU buffer B -> kfgrads
+            GPU buffer B -> arrays
 
         """
+        # Allocate memory space for recieving kfgrads
+        _utility.allocate_kfgrads(fisher_blocks)
+
         cuda_stream = chainer.cuda.Stream.null
 
-        nelems = self.allgatherv_get_nelems(model)
+        # We extract kfgrads from fisher_blocks
+        extractors = [_utility.extract_kfgrads]
+        arrays = _utility.extract(fisher_blocks, self.indices, extractors)
+
+        # Get total number of elements
+        nelems = _utility.get_nelems(arrays)
         nbytes = nelems * self.sizeof_dtype
         self.gpu_buffer_a.assign(nbytes)
         self.gpu_buffer_b.assign(nbytes)
 
-        # Pack the elements in a single buffer, calculate sendcounts, and
-        # calculate displs
+        # Calculate sendcounts, and calculate displs
         #   - sendcounts: the number of elements to send to each process
         #   - displs: the displacements where each segment begins
-        sendcounts, displs = _memory_utility.allgatherv_pack(
-            model, self.divided_linknames, self.gpu_buffer_a,
-            self.sizeof_dtype, self.rank)
+        sendcounts, displs = _utility.get_sendcounts_and_displs(arrays)
+
+        # Pack the elements in a single buffer
+        _utility.pack(arrays[self.rank], self.gpu_buffer_a, self.sizeof_dtype)
 
         # Buffers for Allgatherv
         nbytes_local = sendcounts[self.rank] * self.sizeof_dtype
@@ -122,15 +136,16 @@ def allgatherv(self, model):
                    self.mpi_dtype]
 
         if self.debug:
-            self.allgatherv_debug(model, 'BEFORE')
+            _utility.print_debug_message(self.mpi_comm, arrays,
+                                         'BEFORE ALLGATHERV')
 
         # We must sync before communication
         cuda_stream.synchronize()
 
         self.mpi_comm.Allgatherv(sendbuf, recvbuf)
 
         # Unpack the all elements
-        _memory_utility.allgatherv_unpack(
-            model, self.linknames, self.gpu_buffer_b, self.sizeof_dtype)
+        _utility.unpack(arrays, self.gpu_buffer_b, self.sizeof_dtype)
 
         if self.debug:
-            self.allgatherv_debug(model, 'AFTER')
+            _utility.print_debug_message(self.mpi_comm, arrays,
+                                         'AFTER ALLGATHERV')
diff --git a/dlframeworks/chainer/communicators/kfac_communicators/kfac_communicator_base.py b/dlframeworks/chainer/communicators/kfac_communicators/kfac_communicator_base.py
index d852c32..c057cad 100644
--- a/dlframeworks/chainer/communicators/kfac_communicators/kfac_communicator_base.py
+++ b/dlframeworks/chainer/communicators/kfac_communicators/kfac_communicator_base.py
@@ -1,89 +1,38 @@
 from chainermn.communicators import mpi_communicator_base
 import numpy as np
 
-from dlframeworks.chainer.utils import create_mpi_print
-from dlframeworks.chainer.utils import get_link
-from dlframeworks.chainer.utils import get_linknames
-
 
 class KFACCommunicatorBase(mpi_communicator_base.MpiCommunicatorBase):
 
-    def __init__(self, mpi_comm, use_nccl=False, dynamic=False, debug=False):
-        super(KFACCommunicatorBase, self).__init__(mpi_comm, use_nccl)
-        self.dynamic = dynamic
+    def __init__(self, mpi_comm, debug=False):
+        super(KFACCommunicatorBase, self).__init__(mpi_comm)
         self.debug = debug
         self.is_setup = False
 
-    def allreduce_grad(self, *args, **kwargs):
-        pass
-
-    def setup(self, model):
-        if self.is_setup and not self.dynamic:
+    def setup(self, fisher_blocks):
+        if self.is_setup:
             return
-        rank = self.rank
-        size = self.size
-        linknames = sorted(get_linknames(model))
-        is_worker = True if rank < len(linknames) else False
-        invcomm = self
-        if size > len(linknames):
-            invcomm = self.split(int(is_worker), rank)
-        divided_linknames = np.array_split(linknames, self.size)
+        n = len(fisher_blocks)
+        is_inv_worker = True if self.rank < n else False
+        if self.size > n:
+            inv_comm = self.split(int(is_inv_worker), self.rank)
+        else:
+            inv_comm = self
 
-        self.linknames = linknames
-        self.is_worker = is_worker
-        self.invcomm = invcomm
-        self.divided_linknames = divided_linknames
-        self.is_setup = True
-
-    def reduce_scatterv(self, model, covs, root=0):
-        raise NotImplementedError
+        indices = np.array_split(np.arange(n), self.size)
+        indices = [local_indices.tolist() for local_indices in indices]
 
-    def reduce_scatterv_extract(self, model, covs):
-        linknames = sorted(get_linknames(model))
-        ret = {}
-        for linkname in linknames:
-            ret[linkname] = []
-            if linkname in covs.keys():
-                for cov in covs[linkname]:
-                    ret[linkname].append(cov)
-            link = get_link(model, linkname)
-            for _, param in sorted(link.namedparams()):
-                if param.grad is not None:
-                    ret[linkname].append(param.grad)
-        return ret
-
-    def reduce_scatterv_get_nelems(self, dictionary):
-        nelems = 0
-        for _, arrays in sorted(dictionary.items()):
-            for array in arrays:
-                nelems += array.size
-        return nelems
+        self.is_inv_worker = is_inv_worker
+        self.inv_comm = inv_comm
+        self.indices = indices
+        self.is_setup = True
 
-    def reduce_scatterv_debug(self, dictionary, prefix):
-        mpi_print = create_mpi_print(self.mpi_comm)
-        idx = 0
-        for linkname, arrays in sorted(dictionary.items()):
-            for array in arrays:
-                mpi_print('{} REDUCE_SCATTERV IDX {} MEAN {}'
-                          .format(prefix, idx, array.mean()))
-                idx += 1
+    def allreduce_grad(self, *args, **kwargs):
+        pass
 
-    def allgatherv(self, model):
+    def reduce_scatterv_grad(self, fisher_blocks, root=0):
         raise NotImplementedError
 
-    def allgatherv_get_nelems(self, model):
-        nelems = 0
-        for _, param in sorted(model.namedparams()):
-            if param.kfgrad is None:
-                continue
-            nelems += param.kfgrad.size
-        return nelems
-
-    def allgatherv_debug(self, model, prefix):
-        mpi_print = create_mpi_print(self.mpi_comm)
-        idx = 0
-        for _, param in sorted(model.namedparams()):
-            mpi_print('{} ALLGATHERV IDX {} KFGRAD_MEAN {}'.format(
-                prefix, idx, param.kfgrad.mean()))
-            idx += 1
+    def allgatherv_kfgrad(self, fisher_blocks):
+        raise NotImplementedError
diff --git a/dlframeworks/chainer/optimizers/kfac.py b/dlframeworks/chainer/optimizers/kfac.py
index cd5a7d4..9d4b8cb 100644
--- a/dlframeworks/chainer/optimizers/kfac.py
+++ b/dlframeworks/chainer/optimizers/kfac.py
@@ -1,11 +1,12 @@
 import collections
-import numpy
 
 import chainer
 from chainer import optimizer
 from chainer.backends import cuda
 from chainer.functions import im2col
 
+from dlframeworks.chainer.utils import get_divided_linknames
+
 _default_hyperparam = chainer.optimizer.Hyperparameter()
 _default_hyperparam.lr = 0.001
 _default_hyperparam.momentum = 0.9
@@ -196,32 +197,15 @@ def create_update_rule(self):
 
     def update(self, lossfun=None, *args, **kwds):
         comm = self.communicator
-        if comm is None:
-            self.grad_update(lossfun, *args, **kwds)
-            self.cov_ema_update()
-            if self.t % self.hyperparam.inv_freq == 0 and self.t > 0:
-                self.inv_update()
-        else:
-            if comm.is_grad_worker:
-                if comm.gcomm.rank == 0:
-                    print('grad_update()')
-                self.grad_update(lossfun, *args, **kwds)
-            if comm.is_cov_worker:
-                if comm.ccomm.rank == 0:
-                    print('cov_ema_update()')
-                self.is_done = self.cov_ema_update()
-            if comm.is_inv_worker:
-                if comm.icomm_g.rank == 0:
-                    print('inv_update()')
-                self.is_done = self.inv_update()
+        self.grad_update(lossfun, *args, **kwds)
+        self.cov_ema_update()
+        if comm is not None:
+            comm.reduce_scatterv(self.target, self.cov_ema_dict)
+        self.inv_update()
+        if comm is not None:
+            comm.allgatherv(self.target)
 
     def grad_update(self, lossfun=None, *args, **kwds):
-        comm = self.communicator
-        # ======== Communication
-        if comm is not None:
-            if self.t % self.hyperparam.inv_freq == 1:
-                comm.sendrecv_param(self)
-                self.t_cov += 1
         if lossfun is not None:
             use_cleargrads = getattr(self, '_use_cleargrads', True)
             loss = lossfun(*args, **kwds)
@@ -238,17 +222,6 @@ def grad_update(self, lossfun=None, *args, **kwds):
 
         del loss  # No more backward computation, free memory
 
-        # ======== Communication
-        if comm is not None:
-            synced = comm.allreduce_grad(self)
-            if not synced:
-                return
-            if self.t % self.hyperparam.inv_freq == 0 and self.t > 0:
-                if self.t_inv == 0 and not comm.is_inv_worker:
-                    self.inv_dict = self.allocate_matrices()
-                comm.bcast_inv(self.inv_dict)
-                self.t_inv += 1
-
         # Update param.kfgrad for each layer
         self.kfac_grad_update()
 
@@ -355,57 +328,14 @@ def get_link(self, path):
                 return _link
         return None
 
-    def allocate_matrices(self):
-        dictionary = collections.OrderedDict()
-        for linkname in self.linknames:
-            link = self.get_link(linkname)
-            param_W = self.get_param(linkname + '/W')
-            param_b = self.get_param(linkname + '/b')
-            if param_W is None:
-                raise ValueError('param_W MUST not be None at', linkname)
-            xp = cuda.get_array_module(param_W.data)
-
-            with cuda.get_device_from_array(param_W.data):
-                if isinstance(link, _linear_link):
-                    n_out, n_in = param_W.shape
-                    if param_b is not None:
-                        A = xp.zeros((n_in + 1, n_in + 1))
-                    else:
-                        A = xp.zeros((n_in, n_in))
-                    G = xp.zeros((n_out, n_out))
-                elif isinstance(link, _convolution_2d_link):
-                    c_out, c_in, kh, kw = param_W.shape
-                    if param_b is not None:
-                        A = xp.zeros((c_in*kh*kw + 1, c_in*kh*kw + 1))
-                    else:
-                        A = xp.zeros((c_in*kh*kw, c_in*kh*kw))
-                    G = xp.zeros((c_out, c_out))
-                else:
-                    continue
-                dictionary[linkname] = [A, G]
-        return collections.OrderedDict(
-            sorted(dictionary.items(), key=lambda x: x[0]))
-
     def cov_ema_update(self):
         """Update EMA of covariance for each laeyer.
 
         This function refers `self.rank_dict` to get sorted keys (linknames).
 
         """
-        comm = self.communicator
-        if self.t_cov == 0:
-            self.cov_ema_dict = self.allocate_matrices()
-        # ======== Communication
-        if comm is not None:
-            is_done = comm.sendrecv_param(self)
-            if is_done:
-                return True
         for i, linkname in enumerate(sorted(self.rank_dict.keys())):
             self._cov_ema_update_core(linkname)
-        # ======== Communication
-        if comm is not None:
-            comm.sendrecv_cov_ema(self.cov_ema_dict)
-            self.t_inv += 1
         self.t_cov += 1
 
     def _cov_ema_update_core(self, linkname):
@@ -429,9 +359,6 @@ def _cov_ema_update_core(self, linkname):
         else:
             raise ValueError('Invalid or unsupported shape: {}.'.format(
                 acts.shape))
-        # ======== Communication
-        if comm is not None:
-            comm.allreduce_cov(covs)
         if linkname in self.cov_ema_dict.keys():
             alpha = self.hyperparam.cov_ema_decay
             cov_emas = self.cov_ema_dict[linkname]
@@ -447,32 +374,13 @@ def inv_update(self):
 
         This function refers `self.cov_ema_dict`.
         """
         comm = self.communicator
-        # ======== Communication
-        if comm is not None:
-            if self.t_inv == 0 and not comm.is_cov_worker:
-                self.cov_ema_dict = self.allocate_matrices()
-                comm.sendrecv_cov_ema(self.cov_ema_dict)
-
-        if comm is not None and len(comm.inv_worker_ranks) > 1:
-            index = comm.inv_worker_ranks.index(comm.wcomm.rank)
-            keys = numpy.array(sorted(list(self.cov_ema_dict.keys())))
-            keys = numpy.array_split(keys, len(comm.inv_worker_ranks))
-            my_keys = list(keys[index])
-            self.inv_dict = self.allocate_matrices()
-        else:
-            my_keys = list(self.cov_ema_dict.keys())
-        for key in my_keys:
-            linkname = key
-            emas = self.cov_ema_dict[key]
+        divided_linknames = get_divided_linknames(self.target, comm.size)
+        for linkname in divided_linknames[comm.rank]:
+            emas = self.cov_ema_dict[linkname]
             self._inv_update_core(linkname, emas)
         self.t_inv += 1
-        # ======== Communication
-        if comm is not None:
-            is_done = comm.bcast_inv(self.inv_dict)
-            if is_done:
-                return True
 
     def _inv_update_core(self, linkname, emas):
         """Update the value of `self.inv_dict[linkname]`.
diff --git a/dlframeworks/chainer/utils/__init__.py b/dlframeworks/chainer/utils/__init__.py
index 93f16a4..9b8f81b 100644
--- a/dlframeworks/chainer/utils/__init__.py
+++ b/dlframeworks/chainer/utils/__init__.py
@@ -4,6 +4,7 @@
 from dlframeworks.chainer.utils.debug import create_mpi_print  # NOQA
+from dlframeworks.chainer.utils.link import get_divided_linknames  # NOQA
 from dlframeworks.chainer.utils.link import get_link  # NOQA
 from dlframeworks.chainer.utils.link import get_linknames  # NOQA
 from dlframeworks.chainer.utils.link import get_param  # NOQA
diff --git a/dlframeworks/chainer/utils/link.py b/dlframeworks/chainer/utils/link.py
index 7d19a11..0ffe69f 100644
--- a/dlframeworks/chainer/utils/link.py
+++ b/dlframeworks/chainer/utils/link.py
@@ -1,7 +1,13 @@
+import numpy as np
+
+
 def get_linknames(model):
     linknames = set()
     for paramname, _ in model.namedparams():
         linkname = paramname[:paramname.rfind('/')]
+        # TODO(Yohei):
+        if 'bn' in linkname:
+            continue
         linknames.add(linkname)
     return list(linknames)
 
@@ -16,3 +22,8 @@ def get_param(model, name):
     for paramname, param in model.namedparams():
         if paramname == name:
             return param
+
+
+def get_divided_linknames(model, size):
+    linknames = sorted(get_linknames(model))
+    return np.array_split(linknames, size)
diff --git a/examples/kfac/config/kfc.yaml b/examples/kfac/config/kfc.yaml
index 8e187c2..3c3efef 100644
--- a/examples/kfac/config/kfc.yaml
+++ b/examples/kfac/config/kfc.yaml
@@ -1,4 +1,4 @@
-out: "${HOME}/results/dlframeworks"
+out: "/home/users/Yohei/results/dlframeworks"
 train: /home/share/ILSVRC2012/labels/train008.txt
 train_root: /home/share/ILSVRC2012/ILSVRC2012_img_train
 val: /home/share/ILSVRC2012/labels/val008.txt
diff --git a/examples/kfac/main.py b/examples/kfac/main.py
index 522e3e6..2e5b852 100644
--- a/examples/kfac/main.py
+++ b/examples/kfac/main.py
@@ -84,10 +84,8 @@ def main():
     parser.set_defaults(test=False)
     args = parser.parse_args()
 
-    comm = dlframeworks.chainer.communicators.KFACCommunicator(
-        args.communicator, npergroup=args.npergroup, debug=True, timeout=300,
-        join_cov=args.join_cov, n_cov_workers=args.n_cov_workers)
-    device = comm.wcomm.intra_rank  # GPU is related with intra rank
+    comm = dlframeworks.chainer.communicators.create_communicator(debug=True)
+    device = comm.intra_rank  # GPU is related with intra rank
     chainer.cuda.get_device_from_id(device).use()
 
     model = archs[args.arch]()
@@ -103,19 +101,19 @@ def main():
     try:
         model.to_gpu()
     except chainer.cuda.cupy.cuda.runtime.CUDARuntimeError as e:
-        print('Error occured in {}'.format(comm.wcomm.rank), file=sys.stderr)
+        print('Error occured in {}'.format(comm.rank), file=sys.stderr)
         raise e
 
-    if comm.wcomm.mpi_comm.rank == 0:
+    if comm.mpi_comm.rank == 0:
         print('==========================================')
-        print('Num process (COMM_WORLD): {}'.format(comm.wcomm.mpi_comm.size))
+        print('Num process (COMM_WORLD): {}'.format(comm.mpi_comm.size))
         print('Using {} communicator'.format(args.communicator))
         print('Using {} arch'.format(args.arch))
         print('Num Minibatch-size: {}'.format(args.batchsize))
         print('Num epoch: {}'.format(args.epoch))
         print('==========================================')
 
-    comm.wcomm.mpi_comm.Barrier()
+    comm.mpi_comm.Barrier()
 
     # ======== Create optimizer ========
     optimizer = dlframeworks.chainer.optimizers.KFAC(
@@ -130,109 +128,77 @@ def main():
     optimizer.setup(model)
     optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
 
-    if comm.is_grad_worker or comm.is_cov_worker:
-        if comm.is_grad_worker:
-            # Gradient worker
-            # Load all dataset in memory
-            dataset_class = dlframeworks.chainer.datasets.CroppingDatasetIO
-            sub_comm = comm.gcomm
-            batchsize = args.batchsize
-        else:
-            # Covariance worker
-            # Load dataset in memory when needed
-            dataset_class = dlframeworks.chainer.datasets.CroppingDatasetIO
-            sub_comm = comm.ccomm
-            batchsize = args.cov_batchsize
-
-        mean = np.load(args.mean)
-
-        # ======== Create dataset ========
-        if comm.gcomm.rank == 0 or comm.ccomm.rank == 0:
-            train = dlframeworks.chainer.datasets.read_pairs(args.train)
-            val = dlframeworks.chainer.datasets.read_pairs(args.val)
-        else:
-            train = None
-            val = None
-        train = chainermn.scatter_dataset(train, sub_comm, shuffle=True)
-        val = chainermn.scatter_dataset(val, sub_comm)
-        train_dataset = dataset_class(
-            train, args.train_root, mean, model.insize, model.insize)
-        val_dataset = dataset_class(
-            val, args.val_root, mean, model.insize, model.insize)
-
-        # ======== Create iterator ========
-        if args.iterator == 'process':
-            multiprocessing.set_start_method('forkserver')
-            train_iterator = chainer.iterators.MultiprocessIterator(
-                train_dataset, batchsize, n_processes=args.loaderjob)
-            val_iterator = chainer.iterators.MultiprocessIterator(
-                val_dataset, args.val_batchsize, n_processes=args.loaderjob,
-                repeat=False)
-        elif args.iterator == 'thread':
-            train_iterator = chainer.iterators.MultithreadIterator(
-                train_dataset, batchsize, n_threads=args.loaderjob)
-            val_iterator = chainer.iterators.MultithreadIterator(
-                val_dataset, args.val_batchsize, n_threads=args.loaderjob,
-                repeat=False)
-        else:
-            train_iterator = chainer.iterators.SerialIterator(train_dataset, batchsize)
-            val_iterator = chainer.iterators.SerialIterator(val_dataset, args.val_batchsize,
-                                                            repeat=False, shuffle=False)
-
-        # ======== Create updater ========
-        updater = training.StandardUpdater(train_iterator, optimizer,
-                                           device=device)
-
-        # ======== Create trainer ========
-        if comm.is_cov_worker:
-            def stop_trigger(x):
-                if x.updater.get_optimizer('main').is_done:
-                    return True
-                else:
-                    return False
-            trainer = training.Trainer(updater, stop_trigger, args.out)
-        else:
-            trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)
-
-        # ======== Extend trainer ========
-        val_interval = (10, 'iteration') if args.test else (1, 'epoch')
-        log_interval = (10, 'iteration') if args.test else (1, 'epoch')
-        if comm.is_grad_worker:
-            # Only gradient worker needs to join this extension
-            # Evaluator
-            evaluator = TestModeEvaluator(val_iterator, model, device=device)
-            evaluator = chainermn.create_multi_node_evaluator(evaluator, comm.gcomm)
-            trainer.extend(evaluator, trigger=val_interval)
-
-        # Some display and output extensions are necessary only for one worker.
-        # (Otherwise, there would just be repeated outputs.)
-        if comm.gcomm.rank == 0:
-            trainer.extend(extensions.dump_graph('main/loss'))
-            trainer.extend(extensions.LogReport(trigger=log_interval))
-            trainer.extend(extensions.observe_lr(), trigger=log_interval)
-            trainer.extend(observe_hyperparam('momentum'), trigger=log_interval)
-            trainer.extend(observe_hyperparam('cov_ema_decay'), trigger=log_interval)
-            trainer.extend(observe_hyperparam('inv_freq'), trigger=log_interval)
-            trainer.extend(observe_hyperparam('damping'), trigger=log_interval)
-            trainer.extend(extensions.PrintReport([
-                'epoch', 'iteration', 'main/loss', 'validation/main/loss',
-                'main/accuracy', 'validation/main/accuracy', 'lr'
-            ]), trigger=log_interval)
-            trainer.extend(extensions.ProgressBar(update_interval=10))
-
-        if args.resume:
-            chainer.serializers.load_npz(args.resume, trainer)
-
-        trainer.run()
+    batchsize = args.batchsize
+    # Load all dataset in memory
+    dataset_class = dlframeworks.chainer.datasets.CroppingDatasetIO
+    mean = np.load(args.mean)
+    # ======== Create dataset ========
+    if comm.rank == 0:
+        train = dlframeworks.chainer.datasets.read_pairs(args.train)
+        val = dlframeworks.chainer.datasets.read_pairs(args.val)
     else:
-        # Inverse worker
-        # ======== Create optimizer ========
-        while True:
-            optimizer.update()
-            if optimizer.is_done:
-                break
-        print('Inverse done')
+        train = None
+        val = None
+    train = chainermn.scatter_dataset(train, comm, shuffle=True)
+    val = chainermn.scatter_dataset(val, comm)
+    train_dataset = dataset_class(
+        train, args.train_root, mean, model.insize, model.insize)
+    val_dataset = dataset_class(
+        val, args.val_root, mean, model.insize, model.insize)
+
+
+    # ======== Create iterator ========
+    if args.iterator == 'process':
+        multiprocessing.set_start_method('forkserver')
+        train_iterator = chainer.iterators.MultiprocessIterator(
+            train_dataset, batchsize, n_processes=args.loaderjob)
+        val_iterator = chainer.iterators.MultiprocessIterator(
+            val_dataset, args.val_batchsize, n_processes=args.loaderjob,
+            repeat=False)
+    elif args.iterator == 'thread':
+        train_iterator = chainer.iterators.MultithreadIterator(
+            train_dataset, batchsize, n_threads=args.loaderjob)
+        val_iterator = chainer.iterators.MultithreadIterator(
+            val_dataset, args.val_batchsize, n_threads=args.loaderjob,
+            repeat=False)
+    else:
+        train_iterator = chainer.iterators.SerialIterator(train_dataset, batchsize)
+        val_iterator = chainer.iterators.SerialIterator(val_dataset, args.val_batchsize,
+                                                        repeat=False, shuffle=False)
+
+    # ======== Create updater ========
+    updater = training.StandardUpdater(train_iterator, optimizer,
+                                       device=device)
+
+    # ======== Create trainer ========
+    trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out)
+
+    # ======== Extend trainer ========
+    val_interval = (10, 'iteration') if args.test else (1, 'epoch')
+    log_interval = (10, 'iteration') if args.test else (1, 'epoch')
+
+    # Some display and output extensions are necessary only for one worker.
+    # (Otherwise, there would just be repeated outputs.)
+    if comm.rank == 0:
+        trainer.extend(extensions.dump_graph('main/loss'))
+        trainer.extend(extensions.LogReport(trigger=log_interval))
+        trainer.extend(extensions.observe_lr(), trigger=log_interval)
+        trainer.extend(observe_hyperparam('momentum'), trigger=log_interval)
+        trainer.extend(observe_hyperparam('cov_ema_decay'), trigger=log_interval)
+        trainer.extend(observe_hyperparam('inv_freq'), trigger=log_interval)
+        trainer.extend(observe_hyperparam('damping'), trigger=log_interval)
+        trainer.extend(extensions.PrintReport([
+            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
+            'main/accuracy', 'validation/main/accuracy', 'lr'
+        ]), trigger=log_interval)
+        trainer.extend(extensions.ProgressBar(update_interval=10))
+
+    if args.resume:
+        chainer.serializers.load_npz(args.resume, trainer)
+
+    trainer.run()
 
 
 if __name__ == '__main__':
     main()
+
diff --git a/examples/kfac/parse_ops.py b/examples/kfac/parse_ops.py
index ffc00a8..ad85a82 100644
--- a/examples/kfac/parse_ops.py
+++ b/examples/kfac/parse_ops.py
@@ -14,7 +14,7 @@
 
 
 def parse_ops(ops):
-    hostname = gethostname().strip()
+    hostname = gethostname()[:3]
     if hostname == 'kfc':
         script = """\
 #!/bin/sh
diff --git a/examples/kfac/parse_ops_core.py b/examples/kfac/parse_ops_core.py
index 4271a49..82caae3 100644
--- a/examples/kfac/parse_ops_core.py
+++ b/examples/kfac/parse_ops_core.py
@@ -1,6 +1,6 @@
 import datetime
 import json
-import socket
+from socket import gethostname
 import shutil
 
 import yaml
@@ -26,7 +26,7 @@ def get_time(fmt='%y.%m.%d_%H.%M.%S'):
 
 
 def get_npernode(ops):
-    hostname = socket.gethostname().strip()
+    hostname = gethostname()[:3]
     if hostname == 'kfc':
         return 8
     else: