Skip to content

Commit 1382b89

Browse files
committed
update IRM ATT estimation
1 parent 4b49ecd commit 1382b89

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

doubleml/irm/apo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def _score_elements(self, y, treated, g_hat_d_lvl0, g_hat_d_lvl1, m_hat, smpls):
318318
u_hat = y - g_hat_d_lvl1
319319
weights, weights_bar = self._get_weights()
320320
psi_b = weights * g_hat_d_lvl1 + weights_bar * np.divide(np.multiply(treated, u_hat), m_hat_adj)
321-
psi_a = -1.0 * np.divide(weights, np.mean(weights)) # TODO: check if this is correct
321+
psi_a = -1.0 * np.divide(weights, np.mean(weights))
322322

323323
return psi_a, psi_b
324324

doubleml/irm/irm.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,7 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls):
354354
psi_b = weights * (g_hat1 - g_hat0) + weights_bar * (
355355
np.divide(np.multiply(d, u_hat1), m_hat_adj) - np.divide(np.multiply(1.0 - d, u_hat0), 1.0 - m_hat_adj)
356356
)
357-
if self.score == "ATE":
358-
psi_a = -1.0 * np.divide(weights, np.mean(weights)) # TODO: check if this is correct
359-
else:
360-
assert self.score == "ATTE"
361-
psi_a = -1.0 * weights
357+
psi_a = -1.0 * np.divide(weights, np.mean(weights))
362358
else:
363359
assert callable(self.score)
364360
psi_a, psi_b = self.score(y=y, d=d, g_hat0=g_hat0, g_hat1=g_hat1, m_hat=m_hat_adj, smpls=smpls)

0 commit comments

Comments
 (0)