Skip to content

Commit 4d107c6

Browse files
committed
Increase test coverage
1 parent 0cb862b commit 4d107c6

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

patsy/polynomials.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Poly(object):
3232
3333
.. versionadded:: 0.4.1
3434
"""
35+
3536
def __init__(self):
3637
self._tmp = {}
3738

@@ -110,8 +111,8 @@ def gen_qr(raw_poly, n):
110111
# Q is now orthognoal of degree n. To match what R is doing, we
111112
# need to use the three-term recurrence technique to calculate
112113
# new alpha, beta, and norm.
113-
alpha = (np.sum(x.reshape((-1, 1)) * q[:, :n] ** 2, axis=0) /
114-
np.sum(q[:, :n] ** 2, axis=0))
114+
alpha = (np.sum(x.reshape((-1, 1)) * q[:, :n] ** 2, axis=0)
115+
/ np.sum(q[:, :n] ** 2, axis=0))
115116

116117
# For reasons I don't understand, the norms R uses are based off
117118
# of the diagonal of the r upper triangular matrix.
@@ -137,6 +138,7 @@ def apply_qr(x, n, alpha, norm, beta):
137138
return p
138139
__getstate__ = no_pickling
139140

141+
140142
poly = stateful_transform(Poly)
141143

142144

@@ -145,6 +147,8 @@ def test_poly_compat():
145147
from patsy.test_poly_data import (R_poly_test_x,
146148
R_poly_test_data,
147149
R_poly_num_tests)
150+
from numpy.testing import assert_allclose
151+
148152
lines = R_poly_test_data.split("\n")
149153
tests_ran = 0
150154
start_idx = lines.index("--BEGIN TEST CASE--")
@@ -172,6 +176,14 @@ def test_poly_compat():
172176
output = np.asarray(eval(test_data["output"]))
173177
# Do the actual test
174178
check_stateful(Poly, False, R_poly_test_x, output, **kwargs)
179+
raw_poly = Poly.vander(R_poly_test_x, kwargs['degree'])
180+
if kwargs['raw']:
181+
actual = raw_poly[:, 1:]
182+
else:
183+
alpha, norm, beta = Poly.gen_qr(raw_poly, kwargs['degree'])
184+
actual = Poly.apply_qr(R_poly_test_x, kwargs['degree'], alpha,
185+
norm, beta)[:, 1:]
186+
assert_allclose(actual, output)
175187
tests_ran += 1
176188
# Set up for the next one
177189
start_idx = stop_idx + 1

0 commit comments

Comments
 (0)