@@ -306,12 +306,15 @@ def from_txout(cls, txout):
306306
307307
308308class CTxInWitness (ImmutableSerializable ):
309- """Witness data for a transaction input. """
309+ """Witness data for a single transaction input. """
310310 __slots__ = ['scriptWitness' ]
311311
312312 def __init__ (self , scriptWitness = CScriptWitness ()):
313313 object .__setattr__ (self , 'scriptWitness' , scriptWitness )
314314
315+ def is_null (self ):
316+ return self .scriptWitness .is_null ()
317+
315318 @classmethod
316319 def stream_deserialize (cls , f ):
317320 scriptWitness = CScriptWitness .stream_deserialize (f )
@@ -336,12 +339,50 @@ def from_txinwitness(cls, txinwitness):
336339 else :
337340 return cls (txinwitness .scriptWitness )
338341
342+ class CTxWitness (ImmutableSerializable ):
343+ """Witness data for all inputs to a transaction."""
344+ __slots__ = ['vtxinwit' ]
345+
346+ def __init__ (self , vtxinwit = ()):
347+ object .__setattr__ (self , 'vtxinwit' , vtxinwit )
348+
349+ def is_null (self ):
350+ for n in range (len (self .vtxinwit )):
351+ if not self .vtxinwit [n ].is_null (): return False
352+ return True
353+
354+ # FIXME this cannot be a @classmethod like the others because we need to
355+ # know how many items to deserialize, which comes from len(vin)
356+ def stream_deserialize (self , f ):
357+ vtxinwit = tuple (CTxInWitness .stream_deserialize (f ) for dummy in
358+ range (len (self .vtxinwit )))
359+ return CTxWitness (vtxinwit )
360+
361+ def stream_serialize (self , f ):
362+ for i in range (len (self .vtxinwit )):
363+ self .vtxinwit [i ].stream_serialize (f )
364+
365+ def __repr__ (self ):
366+ return "CTxWitness(%s)" % (',' .join (repr (w ) for w in self .vtxinwit ))
367+
368+ @classmethod
369+ def from_txwitness (cls , txwitness ):
370+ """Create an immutable copy of an existing TxWitness
371+
372+ If txwitness is already immutable (txwitness.__class__ is CTxWitness) it is returned
373+ directly.
374+ """
375+ if txwitness .__class__ is CTxWitness :
376+ return txwitness
377+ else :
378+ return cls (txwitness .vtxinwit )
379+
339380
340381class CTransaction (ImmutableSerializable ):
341382 """A transaction"""
342383 __slots__ = ['nVersion' , 'vin' , 'vout' , 'nLockTime' , 'wit' ]
343384
344- def __init__ (self , vin = (), vout = (), nLockTime = 0 , nVersion = 1 , witness = ()):
385+ def __init__ (self , vin = (), vout = (), nLockTime = 0 , nVersion = 1 , witness = CTxWitness ()):
345386 """Create a new transaction
346387
347388 vin and vout are iterables of transaction inputs and outputs
@@ -351,13 +392,10 @@ def __init__(self, vin=(), vout=(), nLockTime=0, nVersion=1, witness=()):
351392 if not (0 <= nLockTime <= 0xffffffff ):
352393 raise ValueError ('CTransaction: nLockTime must be in range 0x0 to 0xffffffff; got %x' % nLockTime )
353394 object .__setattr__ (self , 'nLockTime' , nLockTime )
354-
355395 object .__setattr__ (self , 'nVersion' , nVersion )
356396 object .__setattr__ (self , 'vin' , tuple (CTxIn .from_txin (txin ) for txin in vin ))
357397 object .__setattr__ (self , 'vout' , tuple (CTxOut .from_txout (txout ) for txout in vout ))
358- object .__setattr__ (self , 'wit' ,
359- tuple (CTxInWitness .from_txinwitness (witness ) for txinwitness in
360- witness ))
398+ object .__setattr__ (self , 'wit' , CTxWitness .from_txwitness (witness ))
361399
362400 @classmethod
363401 def stream_deserialize (cls , f ):
@@ -370,7 +408,8 @@ def stream_deserialize(cls, f):
370408 raise DeserializationFormatError
371409 vin = VectorSerializer .stream_deserialize (CTxIn , f )
372410 vout = VectorSerializer .stream_deserialize (CTxOut , f )
373- wit = VectorSerializer .stream_deserialize (CTxInWitness , f )
411+ wit = CTxWitness (tuple (0 for dummy in range (len (vin ))))
412+ wit = wit .stream_deserialize (f )
374413 nLockTime = struct .unpack (b"<I" , ser_read (f ,4 ))[0 ]
375414 return cls (vin , vout , nLockTime , nVersion , wit )
376415 else :
@@ -382,15 +421,15 @@ def stream_deserialize(cls, f):
382421
383422
384423 def stream_serialize (self , f ):
385- if self .wit :
386- if len (self .wit ) != len (self .vin ):
424+ if not self .wit . is_null () :
425+ if len (self .wit . vtxinwit ) != len (self .vin ):
387426 raise SerializationMissingWitnessError
388427 f .write (struct .pack (b"<i" , self .nVersion ))
389428 f .write (b'\x00 ' ) # Marker
390429 f .write (b'\x01 ' ) # Flag
391430 VectorSerializer .stream_serialize (CTxIn , self .vin , f )
392431 VectorSerializer .stream_serialize (CTxOut , self .vout , f )
393- for w in self .wit : w .stream_serialize (f )
432+ self .wit .stream_serialize (f )
394433 f .write (struct .pack (b"<I" , self .nLockTime ))
395434 else :
396435 f .write (struct .pack (b"<i" , self .nVersion ))
@@ -422,11 +461,9 @@ def GetTxid(self):
422461 """Get the transaction ID. This differs from the transactions hash as
423462 given by GetHash. GetTxid excludes witness data, while GetHash
424463 includes it. """
425- if self .wit :
426- wit = self .wit
427- self .wit = b''
428- txid = Hash (self .serialize ())
429- self .wit = wit
464+ if self .wit != CTxWitness ():
465+ txid = Hash (CTransaction (self .vin , self .vout , self .nLockTime ,
466+ self .nVersion ).serialize ())
430467 else :
431468 txid = Hash (self .serialize ())
432469 return txid
@@ -452,6 +489,9 @@ def __init__(self, vin=None, vout=None, nLockTime=0, nVersion=1, witness=None):
452489 vout = []
453490 self .vout = vout
454491 self .nVersion = nVersion
492+
493+ if witness is None :
494+ witness = CTxWitness ([CTxInWitness () for dummy in range (len (vin ))])
455495 self .wit = witness
456496
457497 @classmethod
0 commit comments