Skip to content

Commit 2ffc631

Browse files
authored
Merge pull request #109 from codeNinja62/master
Refactor authentication, payload handling, and fix session/console bugs
2 parents b04ef99 + c191da8 commit 2ffc631

File tree

1 file changed

+120
-93
lines changed

1 file changed

+120
-93
lines changed

pymetasploit3/msfrpc.py

Lines changed: 120 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
'ConsoleManager'
5151
]
5252

53-
5453
class MsfRpcError(Exception):
5554
pass
5655

@@ -200,9 +199,25 @@ def __init__(self, password, **kwargs):
200199
self.encodings = kwargs.get('encodings', ['utf-8'])
201200
self.decode_error_handling: str = kwargs.get('decode_error_handling', 'strict')
202201
self.headers = {"Content-type": "binary/message-pack"}
202+
self.persistentlogin = True # If True, the login token will not be removed on logout
203+
203204
if self.token is None:
204205
self.login(kwargs.get('username', 'msf'), password)
205206

207+
def __persistentlogin__(self, value=None):
208+
"""
209+
Get or set whether the login token should be removed on logout.
210+
If called without arguments, returns the current value.
211+
If called with a boolean, sets the value.
212+
If called with a Number, raises TypeError.
213+
"""
214+
if value is None:
215+
return self.persistentlogin
216+
if isinstance(value, bool):
217+
self.persistentlogin = value
218+
elif isinstance(value, Number):
219+
raise TypeError("Expected a boolean value for persistentlogin")
220+
206221
def call(self, method, opts=None, is_raw=False):
207222
if not isinstance(opts, list):
208223
opts = []
@@ -247,17 +262,18 @@ def login(self, user, password):
247262

248263
def add_perm_token(self):
249264
"""
250-
Add a permanent UUID4 API token
265+
Add a UUID4 API token
251266
"""
252267
token = str(uuid.uuid4())
253268
self.call(MsfRpcMethod.AuthTokenAdd, [token])
254269
return token
255270

256271
def logout(self):
257-
"""
258-
Logs the current user out. Note: do not call directly.
259-
"""
260-
self.call(MsfRpcMethod.AuthLogout, [self.token])
272+
"""Logs out and removes the API token"""
273+
if self.token and not self.__persistentlogin__:
274+
self.call(MsfRpcMethod.AuthLogout, [self.token])
275+
self.call(MsfRpcMethod.AuthTokenRemove, [self.token])
276+
self.token = None
261277

262278
@property
263279
def core(self):
@@ -1421,6 +1437,39 @@ def payload_generate(self, **kwargs):
14211437
except (msgpack.exceptions.ExtraData, UnicodeDecodeError):
14221438
return payload
14231439
return payload
1440+
1441+
def _handle_payload(self, runopts, payload):
1442+
"""Centralized payload handling for exploits"""
1443+
if not payload:
1444+
if 'DisablePayloadHandler' not in runopts or not runopts['DisablePayloadHandler']:
1445+
runopts['DisablePayloadHandler'] = True
1446+
return runopts
1447+
1448+
# Only ExploitModules should handle payloads
1449+
if not isinstance(self, ExploitModule):
1450+
return runopts # Silently ignore payload for non-exploits
1451+
1452+
if isinstance(payload, PayloadModule):
1453+
if payload.modulename not in self.payloads:
1454+
raise ValueError(f'Invalid payload ({payload.modulename}) for target ({self.target})')
1455+
runopts['PAYLOAD'] = payload.modulename
1456+
1457+
# Merge payload options without overwriting existing settings
1458+
for k, v in payload.runoptions.items():
1459+
if v is None or (isinstance(v, str) and not v):
1460+
continue
1461+
if k not in runopts or not runopts[k]:
1462+
runopts[k] = v
1463+
1464+
elif isinstance(payload, str):
1465+
if payload not in self.payloads:
1466+
raise ValueError(f'Invalid payload ({payload}) for target ({self.target})')
1467+
runopts['PAYLOAD'] = payload
1468+
1469+
else:
1470+
raise TypeError(f"Expected PayloadModule or str, got {type(payload).__name__}")
1471+
1472+
return runopts
14241473

14251474
def execute(self, **kwargs):
14261475
"""
@@ -1431,35 +1480,17 @@ def execute(self, **kwargs):
14311480
- **kwargs : can contain any module options.
14321481
"""
14331482
runopts = self.runoptions.copy()
1483+
1484+
# Handle exploit-specific payload logic
14341485
if isinstance(self, ExploitModule):
1435-
payload = kwargs.get('payload')
14361486
runopts['TARGET'] = self.target
1437-
if 'DisablePayloadHandler' in runopts and runopts['DisablePayloadHandler']:
1438-
pass
1439-
elif payload is None:
1440-
runopts['DisablePayloadHandler'] = True
1441-
else:
1442-
if isinstance(payload, PayloadModule):
1443-
if payload.modulename not in self.payloads:
1444-
raise ValueError(
1445-
'Invalid payload (%s) for given target (%d).' % (payload.modulename, self.target)
1446-
)
1447-
runopts['PAYLOAD'] = payload.modulename
1448-
for k, v in payload.runoptions.items():
1449-
if v is None or (isinstance(v, str) and not v):
1450-
continue
1451-
if k not in runopts or runopts[k] is None or \
1452-
(isinstance(runopts[k], str) and not runopts[k]):
1453-
runopts[k] = v
1454-
# runopts.update(payload.runoptions)
1455-
elif isinstance(payload, str):
1456-
if payload not in self.payloads:
1457-
raise ValueError('Invalid payload (%s) for given target (%d).' % (payload, self.target))
1458-
runopts['PAYLOAD'] = payload
1459-
else:
1460-
raise TypeError("Expected type str or PayloadModule not '%s'" % type(kwargs['payload']).__name__)
1461-
1462-
return self.rpc.call(MsfRpcMethod.ModuleExecute, [self.moduletype, self.modulename, runopts])
1487+
runopts = self._handle_payload(runopts, kwargs.get('payload'))
1488+
1489+
# Merge additional runtime options
1490+
runopts.update({k: v for k, v in kwargs.items() if k != 'payload'})
1491+
1492+
return self.rpc.call(MsfRpcMethod.ModuleExecute,
1493+
[self.moduletype, self.modulename, runopts])
14631494

14641495
def check(self, **kwargs):
14651496
"""
@@ -1468,37 +1499,26 @@ def check(self, **kwargs):
14681499
Optional Keyword Arguments:
14691500
- **kwargs : can contain any module options.
14701501
"""
1502+
# Create a copy of module options
14711503
runopts = self.runoptions.copy()
1504+
1505+
# First merge user-provided options (except payload)
1506+
user_opts = {k: v for k, v in kwargs.items() if k != 'payload'}
1507+
runopts.update(user_opts)
1508+
1509+
# Handle exploit-specific logic
14721510
if isinstance(self, ExploitModule):
1473-
payload = kwargs.get('payload')
1511+
# Ensure target is set (user options might override)
14741512
runopts['TARGET'] = self.target
1475-
if 'DisablePayloadHandler' in runopts and runopts['DisablePayloadHandler']:
1476-
pass
1477-
elif payload is None:
1478-
runopts['DisablePayloadHandler'] = True
1479-
else:
1480-
if isinstance(payload, PayloadModule):
1481-
if payload.modulename not in self.payloads:
1482-
raise ValueError(
1483-
'Invalid payload (%s) for given target (%d).' % (payload.modulename, self.target)
1484-
)
1485-
runopts['PAYLOAD'] = payload.modulename
1486-
for k, v in payload.runoptions.items():
1487-
if v is None or (isinstance(v, str) and not v):
1488-
continue
1489-
if k not in runopts or runopts[k] is None or \
1490-
(isinstance(runopts[k], str) and not runopts[k]):
1491-
runopts[k] = v
1492-
# runopts.update(payload.runoptions)
1493-
elif isinstance(payload, str):
1494-
if payload not in self.payloads:
1495-
raise ValueError('Invalid payload (%s) for given target (%d).' % (payload, self.target))
1496-
runopts['PAYLOAD'] = payload
1497-
else:
1498-
raise TypeError("Expected type str or PayloadModule not '%s'" % type(kwargs['payload']).__name__)
1499-
1500-
return self.rpc.call(MsfRpcMethod.ModuleCheck, [self.moduletype, self.modulename, runopts])
1501-
1513+
1514+
# Process payload if provided
1515+
runopts = self._handle_payload(runopts, kwargs.get('payload'))
1516+
1517+
# Always disable payload handler for checks
1518+
runopts['DisablePayloadHandler'] = True
1519+
1520+
return self.rpc.call(MsfRpcMethod.ModuleCheck,
1521+
[self.moduletype, self.modulename, runopts])
15021522

15031523
class ExploitModule(MsfModule):
15041524

@@ -2146,34 +2166,41 @@ def gather_output(self, cmd, end_strs, timeout):
21462166

21472167
class SessionManager(MsfManager):
21482168

2169+
21492170
@property
21502171
def list(self):
21512172
"""
21522173
A list of active sessions.
2153-
"""
2154-
return {str(k): v for k, v in self.rpc.call(MsfRpcMethod.SessionList).items()} # Convert int id to str
2155-
2156-
def session(self, sid):
2157-
"""
2158-
Returns a session object for meterpreter or shell sessions.
21592174
2160-
Mandatory Arguments:
2161-
- sid : the session identifier or uuid
2175+
Return session list with native integer IDs
21622176
"""
2177+
return self.rpc.call(MsfRpcMethod.SessionList)
2178+
2179+
def session(self, sid):
21632180
s = self.list
2164-
if sid not in s:
2165-
for k in s:
2166-
if s[k]['uuid'] == sid:
2167-
if s[k]['type'] == 'meterpreter':
2168-
return MeterpreterSession(k, self.rpc, s)
2169-
elif s[k]['type'] == 'shell':
2170-
return ShellSession(k, self.rpc, s)
2171-
raise KeyError('Session ID (%s) does not exist' % sid)
2172-
if s[sid]['type'] == 'meterpreter':
2173-
return MeterpreterSession(sid, self.rpc, s)
2174-
elif s[sid]['type'] == 'shell':
2175-
return ShellSession(sid, self.rpc, s)
2176-
raise NotImplementedError('Could not determine session type: %s' % s[sid]['type'])
2181+
# Handle UUID lookup (string)
2182+
for session_id, session_data in s.items():
2183+
if session_data.get('uuid') == sid:
2184+
return self._create_session(session_id, s, session_data)
2185+
2186+
# Handle integer ID
2187+
try:
2188+
int_id = int(sid) if isinstance(sid, str) else sid
2189+
if int_id in s:
2190+
return self._create_session(int_id, s, s[int_id])
2191+
except (ValueError, TypeError):
2192+
pass
2193+
2194+
raise KeyError(f'Session ID ({sid}) does not exist')
2195+
2196+
def _create_session(self, session_id, session_data_dict, session_data):
2197+
"""PRIVATE method - maintain backward compatibility"""
2198+
stype = session_data['type']
2199+
if stype == 'meterpreter':
2200+
return MeterpreterSession(session_id, self.rpc, session_data_dict)
2201+
elif stype == 'shell':
2202+
return ShellSession(session_id, self.rpc, session_data_dict)
2203+
raise NotImplementedError(f'Unsupported session type: {stype}')
21772204

21782205

21792206
class MsfConsole(object):
@@ -2316,20 +2343,20 @@ def list(self):
23162343
return self.rpc.call(MsfRpcMethod.ConsoleList)['consoles']
23172344

23182345
def console(self, cid=None):
2319-
"""
2320-
Connect to an active console otherwise create a new console.
2321-
2322-
Optional Keyword Arguments:
2323-
- cid : the console identifier.
2324-
"""
2325-
s = [i['id'] for i in self.list]
2346+
consoles = self.list
2347+
console_ids = [c['id'] for c in consoles]
2348+
23262349
if cid is None:
23272350
return MsfConsole(self.rpc)
2328-
if cid not in s:
2329-
raise KeyError('Console ID (%s) does not exist' % cid)
2330-
else:
2351+
2352+
# Handle string representations of integers
2353+
if isinstance(cid, str) and cid.isdigit():
2354+
cid = int(cid)
2355+
2356+
if cid in console_ids:
23312357
return MsfConsole(self.rpc, cid=cid)
2332-
2358+
2359+
raise KeyError(f'Console ID ({cid}) does not exist')
23332360
def destroy(self, cid):
23342361
"""
23352362
Destroy an active console.

0 commit comments

Comments
 (0)