Skip to content

Commit 3c468d1

Browse files
committed
Detect the program name when python -m was executed
1 parent 51afa53 commit 3c468d1

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

fire/core.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def Fire(component=None, command=None, name=None, serialize=None):
109109
code 2. When used with the help or trace flags, Fire will raise a
110110
FireExit with code 0 if successful.
111111
"""
112-
name = name or os.path.basename(sys.argv[0])
112+
113+
name = _GetProgName(name)
113114

114115
# Get args as a list.
115116
if isinstance(command, six.string_types):
@@ -281,6 +282,45 @@ def _PrintResult(component_trace, verbose=False, serialize=None):
281282
Display(output, out=sys.stdout)
282283

283284

285+
def _GetProgName(name, main=None):
286+
"""Determines the program name.
287+
288+
This function returns the program name that should be
289+
displayed.
290+
291+
If ``python -m`` was used to execute a module, ``python -m name`` will be
292+
returned, instead of ``__main__.py``.
293+
294+
Args:
295+
name: Optional. The name of the command as entered at the command line.
296+
main: Optional. This should only be passed during testing.
297+
Returns:
298+
The program name determined by this function.
299+
"""
300+
if name:
301+
return name
302+
303+
name_from_arg = os.path.basename(sys.argv[0])
304+
305+
if main:
306+
py_module = main
307+
else:
308+
py_module = sys.modules['__main__'].__package__ # pylint: disable=no-member
309+
310+
if py_module is not None:
311+
if name_from_arg == '__main__.py':
312+
return '{executable} -m {module}'.format(
313+
executable=sys.executable, module=py_module)
314+
else:
315+
# For example: python -m sample.cli
316+
name = os.path.splitext(name_from_arg)[0]
317+
py_module = '{module}.{name}'.format(module=py_module, name=name)
318+
return '{executable} -m {module}'.format(
319+
executable=sys.executable, module=py_module.lstrip('.'))
320+
else:
321+
return name_from_arg
322+
323+
284324
def _DisplayError(component_trace):
285325
"""Prints the Fire trace and the error to stdout."""
286326
result = component_trace.GetResult()

fire/fire_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import fire
2525
from fire import test_components as tc
2626
from fire import testutils
27+
from fire.core import _GetProgName
2728

2829
import mock
2930
import six
@@ -63,6 +64,17 @@ def testFireDefaultName(self):
6364
stderr=None):
6465
fire.Fire(tc.Empty)
6566

67+
def testFireProgName(self):
68+
# ``sample/__main__.py` would be argv[0] when ``python -m sample`` was
69+
# executed.
70+
with mock.patch.object(sys, 'argv', ['sample/__main__.py']):
71+
self.assertEqual(_GetProgName('', main='sample'),
72+
'{executable} -m sample'.format(executable=sys.executable))
73+
74+
with mock.patch.object(sys, 'argv', ['sample/cli.py']):
75+
self.assertEqual(_GetProgName('', main='sample'),
76+
'{executable} -m sample.cli'.format(executable=sys.executable))
77+
6678
def testFireNoArgs(self):
6779
self.assertEqual(fire.Fire(tc.MixedDefaults, command=['ten']), 10)
6880

0 commit comments

Comments
 (0)