Skip to content

Commit c114e5e

Browse files
authored
Add weighted to plot (#72)
* Add weighted to plot * bump version
1 parent 80b5c10 commit c114e5e

File tree

6 files changed

+52
-7
lines changed

6 files changed

+52
-7
lines changed

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cff-version: 1.2.0
22
message: "If you use this software, please cite it as below."
33
type: software
44
title: "dte_adj: A Python Package for Estimating Distribution Treatment Effects"
5-
version: 0.1.7
5+
version: 0.1.8
66
date-released: 2024-12-01
77
url: "https://github.com/CyberAgentAILab/python-dte-adjustment"
88
repository-code: "https://github.com/CyberAgentAILab/python-dte-adjustment"

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
project = "dte_adj"
1414
copyright = "2024, CyberAgent, Inc."
1515
author = "CyberAgent, Inc"
16-
release = "0.1.7"
16+
release = "0.1.8"
1717

1818
# -- General configuration ---------------------------------------------------
1919
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

dte_adj/plot/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def plot(
1515
title: Optional[str] = None,
1616
xlabel: Optional[str] = None,
1717
ylabel: Optional[str] = None,
18+
weighted: bool = False,
1819
):
1920
"""Visualize distributional parameters and their confidence intervals.
2021
@@ -29,12 +30,18 @@ def plot(
2930
title (str, optional): Axes title.
3031
xlabel (str, optional): X-axis title label.
3132
ylabel (str, optional): Y-axis title label.
33+
weighted (bool, optional): If True, multiply treatment effects by X values to show value-weighted effects. Defaults to False.
3234
3335
Returns:
3436
matplotlib.axes.Axes: The axes with the plot.
3537
"""
3638
if ax is None:
37-
fig, ax = plt.subplots()
39+
_, ax = plt.subplots()
40+
41+
if weighted:
42+
means = means * X
43+
lower_bounds = lower_bounds * X
44+
upper_bounds = upper_bounds * X
3845

3946
if chart_type == "line":
4047
ax.plot(X, means, label="Values", color=color)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "dte_adj"
7-
version = "0.1.7"
7+
version = "0.1.8"
88
description = "This is a Python library for estimating distributional treatment effects"
99
readme = "README.md"
1010
requires-python = ">=3.10"

tests/test_plot.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,44 @@ def test_plot_fail_unknown_chart_type(self):
7272
"Chart type other is not supported",
7373
)
7474

75+
@patch("dte_adj.plot.plt")
76+
def test_plot_weighted(self, mock_plt):
77+
# Arrange
78+
x_values = np.array([1, 2, 3, 4, 5])
79+
means = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
80+
upper_bands = np.array([0.2, 0.3, 0.4, 0.5, 0.6])
81+
lower_bands = np.array([0.0, 0.1, 0.2, 0.3, 0.4])
82+
mock_ax = MagicMock()
83+
mock_plt.subplots.return_value = (MagicMock(), mock_ax)
84+
85+
# Act
86+
result_ax = plot(
87+
x_values,
88+
means,
89+
lower_bands,
90+
upper_bands,
91+
chart_type="line",
92+
weighted=True,
93+
)
7594

76-
if __name__ == "__main__":
77-
unittest.main()
95+
# Assert
96+
self.assertEqual(result_ax, mock_ax)
97+
mock_plt.subplots.assert_called_once()
98+
plot_call = mock_ax.plot.call_args
99+
fill_between_call = mock_ax.fill_between.call_args
100+
101+
# Check that values are weighted (multiplied by x_values)
102+
plot_args, plot_kwargs = plot_call
103+
x_values_arg, y_values_arg = plot_args
104+
expected_weighted_means = means * x_values
105+
self.assertTrue(np.array_equal(x_values_arg, x_values))
106+
self.assertTrue(np.array_equal(y_values_arg, expected_weighted_means))
107+
108+
# Check that confidence intervals are also weighted
109+
fill_between_args, fill_between_kwargs = fill_between_call
110+
x_fill, lower_fill, upper_fill = fill_between_args
111+
expected_weighted_lower = lower_bands * x_values
112+
expected_weighted_upper = upper_bands * x_values
113+
self.assertTrue(np.array_equal(x_fill, x_values_arg))
114+
self.assertTrue(np.array_equal(lower_fill, expected_weighted_lower))
115+
self.assertTrue(np.array_equal(upper_fill, expected_weighted_upper))

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)