|
1 | 1 | from datetime import datetime
|
2 | 2 |
|
| 3 | +from pandas import DataFrame, MultiIndex |
| 4 | + |
3 | 5 | import pytest
|
4 | 6 |
|
5 | 7 | import pandas_datareader.data as web
|
@@ -43,21 +45,22 @@ def test_single_symbol(self):
|
43 | 45 | def test_multiple_symbols(self):
|
44 | 46 | syms = ["AAPL", "MSFT", "TSLA"]
|
45 | 47 | df = web.DataReader(syms, "iex", self.start, self.end)
|
46 |
| - assert sorted(list(df)) == syms |
| 48 | + assert sorted(list(df.columns.levels[1])) == syms |
47 | 49 | for sym in syms:
|
48 |
| - assert len(df[sym] == 578) |
| 50 | + assert len(df.xs(sym, level='Symbols', axis=1) == 578) |
49 | 51 |
|
50 | 52 | def test_multiple_symbols_2(self):
|
51 | 53 | syms = ["AAPL", "MSFT", "TSLA"]
|
52 | 54 | good_start = datetime(2017, 2, 9)
|
53 | 55 | good_end = datetime(2017, 5, 24)
|
54 | 56 | df = web.DataReader(syms, "iex", good_start, good_end)
|
55 |
| - assert isinstance(df, dict) |
56 |
| - assert len(df) == 3 |
57 |
| - assert sorted(list(df)) == syms |
| 57 | + assert isinstance(df, DataFrame) |
| 58 | + assert isinstance(df.columns, MultiIndex) |
| 59 | + assert len(df.columns.levels[1]) == 3 |
| 60 | + assert sorted(list(df.columns.levels[1])) == syms |
58 | 61 |
|
59 |
| - a = df["AAPL"] |
60 |
| - t = df["TSLA"] |
| 62 | + a = df.xs("AAPL", axis=1, level='Symbols') |
| 63 | + t = df.xs("TSLA", axis=1, level='Symbols') |
61 | 64 |
|
62 | 65 | assert len(a) == 73
|
63 | 66 | assert len(t) == 73
|
|
0 commit comments