diff --git a/pandas/tools/merge.py b/pandas/tools/merge.py index fa9517dadd432..d64d79913b10f 100644 --- a/pandas/tools/merge.py +++ b/pandas/tools/merge.py @@ -628,8 +628,10 @@ def _prepare_blocks(self): for unit in self.units: join_blocks = unit.get_upcasted_blocks() - type_map = dict((type(blk), blk) for blk in join_blocks) - blockmaps.append(type_map) + type_map = {} + for blk in join_blocks: + type_map.setdefault(type(blk), []).append(blk) + blockmaps.append((unit, type_map)) return blockmaps @@ -640,26 +642,22 @@ def get_result(self): merged : BlockManager """ blockmaps = self._prepare_blocks() - kinds = _get_all_block_kinds(blockmaps) + kinds = _get_merge_block_kinds(blockmaps) result_blocks = [] # maybe want to enable flexible copying <-- what did I mean? for klass in kinds: - klass_blocks = [mapping.get(klass) for mapping in blockmaps] + klass_blocks = [] + for unit, mapping in blockmaps: + if klass in mapping: + klass_blocks.extend((unit, b) for b in mapping[klass]) res_blk = self._get_merged_block(klass_blocks) result_blocks.append(res_blk) return BlockManager(result_blocks, self.result_axes) - def _get_merged_block(self, blocks): - - to_merge = [] - - for unit, block in zip(self.units, blocks): - if block is not None: - to_merge.append((unit, block)) - + def _get_merged_block(self, to_merge): if len(to_merge) > 1: return self._merge_blocks(to_merge) else: @@ -682,7 +680,8 @@ def _merge_blocks(self, merge_chunks): out_shape[self.axis] = n # Should use Fortran order?? - out = np.empty(out_shape, dtype=fblock.values.dtype) + block_dtype = _get_block_dtype([x[1] for x in merge_chunks]) + out = np.empty(out_shape, dtype=block_dtype) sofar = 0 for unit, blk in merge_chunks: @@ -787,6 +786,25 @@ def _get_all_block_kinds(blockmaps): kinds |= set(mapping) return kinds +def _get_merge_block_kinds(blockmaps): + kinds = set() + for _, mapping in blockmaps: + kinds |= set(mapping) + return kinds + +def _get_block_dtype(blocks): + if len(blocks) == 0: + return object + blk1 = blocks[0] + dtype = blk1.dtype + + if issubclass(dtype.type, np.floating): + for blk in blocks: + if blk.dtype.type == np.float64: + return blk.dtype + + return dtype + #---------------------------------------------------------------------- # Concatenate DataFrame objects @@ -928,16 +946,20 @@ def get_result(self): def _get_fresh_axis(self): return Index(np.arange(len(self._get_concat_axis()))) + def _prepare_blocks(self): + reindexed_data = self._get_reindexed_data() + + blockmaps = [] + for data in reindexed_data: + data = data.consolidate() + type_map = dict((type(blk), blk) for blk in data.blocks) + blockmaps.append(type_map) + return blockmaps + def _get_concatenated_data(self): try: # need to conform to same other (joined) axes for block join - reindexed_data = self._get_reindexed_data() - - blockmaps = [] - for data in reindexed_data: - data = data.consolidate() - type_map = dict((type(blk), blk) for blk in data.blocks) - blockmaps.append(type_map) + blockmaps = self._prepare_blocks() kinds = _get_all_block_kinds(blockmaps) new_blocks = [] diff --git a/pandas/tools/tests/test_merge.py b/pandas/tools/tests/test_merge.py index 38db2f2b602ea..829471deb9e6d 100644 --- a/pandas/tools/tests/test_merge.py +++ b/pandas/tools/tests/test_merge.py @@ -414,6 +414,19 @@ def test_join_float64_float32(self): expected = a.join(b.astype('f8')) assert_frame_equal(joined, expected) + joined = b.join(a) + assert_frame_equal(expected, joined.reindex(columns=['a', 'b', 'c'])) + + a = np.random.randint(0, 5, 100) + b = np.random.random(100).astype('Float64') + c = np.random.random(100).astype('Float32') + df = DataFrame({'a': a, 'b' : b, 'c' : c}) + xpdf = DataFrame({'a': a, 'b' : b, 'c' : c.astype('Float64')}) + s = DataFrame(np.random.random(5).astype('f'), columns=['md']) + rs = df.merge(s, left_on='a', right_index=True) + xp = xpdf.merge(s.astype('f8'), left_on='a', right_index=True) + assert_frame_equal(rs, xp) + def test_join_many_non_unique_index(self): df1 = DataFrame({"a": [1,1], "b": [1,1], "c": [10,20]}) df2 = DataFrame({"a": [1,1], "b": [1,2], "d": [100,200]}) @@ -1466,5 +1479,3 @@ def test_multigroup(self): import nose nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'], exit=False) - -