Skip to content

Commit b16931c

Browse files
dlukestomaarsen
andauthored
Refactor dispersion plot (nltk#3082)
* Modernize Matplotlib code for dispersion plot - Use object-oriented instead of state-machine interface. - Return Axes object to allow additional customization (see nltk#2239). - Remove useless scalex kwarg (it's supposed to be a bool and True by default, so passing in 0.1 is confusing). * Use default palette in dispersion plot * Refactor data preparation in dispersion plot Make the code a bit more concise and readable for beginners, who may want to use it as a starting point for their own tweaked dispersion plot. (Incidentally, this version is also a bit faster since it replaces the nested loop over words with the in operator on a dict, but that's not the main goal.) * Casefold instead of lower in dispersion plot str.casefold is the method primarily meant for caseless comparison. * Dispersion plot docstring tweak * Reraise ImportError if importing matplotlib fails Rather than ValueError. Additionally, add a space between "... installed." and "See ..." * Add docstring for return value to dispersion plot Co-authored-by: Tom Aarsen <[email protected]>
1 parent f019fbe commit b16931c

File tree

1 file changed

+28
-30
lines changed

1 file changed

+28
-30
lines changed

nltk/draw/dispersion.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,49 @@ def dispersion_plot(text, words, ignore_case=False, title="Lexical Dispersion Pl
1515
Generate a lexical dispersion plot.
1616
1717
:param text: The source text
18-
:type text: list(str) or enum(str)
18+
:type text: list(str) or iter(str)
1919
:param words: The target words
2020
:type words: list of str
2121
:param ignore_case: flag to set if case should be ignored when searching text
2222
:type ignore_case: bool
23+
:return: a matplotlib Axes object that may still be modified before plotting
24+
:rtype: Axes
2325
"""
2426

2527
try:
26-
from matplotlib import pylab
28+
import matplotlib.pyplot as plt
2729
except ImportError as e:
28-
raise ValueError(
29-
"The plot function requires matplotlib to be installed."
30+
raise ImportError(
31+
"The plot function requires matplotlib to be installed. "
3032
"See https://matplotlib.org/"
3133
) from e
3234

33-
text = list(text)
34-
words.reverse()
35-
36-
if ignore_case:
37-
words_to_comp = list(map(str.lower, words))
38-
text_to_comp = list(map(str.lower, text))
39-
else:
40-
words_to_comp = words
41-
text_to_comp = text
42-
43-
points = [
44-
(x, y)
45-
for x in range(len(text_to_comp))
46-
for y in range(len(words_to_comp))
47-
if text_to_comp[x] == words_to_comp[y]
48-
]
49-
if points:
50-
x, y = list(zip(*points))
51-
else:
52-
x = y = ()
53-
pylab.plot(x, y, "b|", scalex=0.1)
54-
pylab.yticks(list(range(len(words))), words, color="b")
55-
pylab.ylim(-1, len(words))
56-
pylab.title(title)
57-
pylab.xlabel("Word Offset")
58-
pylab.show()
35+
word2y = {
36+
word.casefold() if ignore_case else word: y
37+
for y, word in enumerate(reversed(words))
38+
}
39+
xs, ys = [], []
40+
for x, token in enumerate(text):
41+
token = token.casefold() if ignore_case else token
42+
y = word2y.get(token)
43+
if y is not None:
44+
xs.append(x)
45+
ys.append(y)
46+
47+
_, ax = plt.subplots()
48+
ax.plot(xs, ys, "|")
49+
ax.set_yticks(list(range(len(words))), words, color="C0")
50+
ax.set_ylim(-1, len(words))
51+
ax.set_title(title)
52+
ax.set_xlabel("Word Offset")
53+
return ax
5954

6055

6156
if __name__ == "__main__":
57+
import matplotlib.pyplot as plt
58+
6259
from nltk.corpus import gutenberg
6360

6461
words = ["Elinor", "Marianne", "Edward", "Willoughby"]
6562
dispersion_plot(gutenberg.words("austen-sense.txt"), words)
63+
plt.show()

0 commit comments

Comments
 (0)