| 
 | 1 | +# Contains definitions for api from MPSGraphTensor.h, MPSGraphTensorData.h, MPSGraphOperation.h  | 
 | 2 | + | 
 | 3 | +## MPSGraphTensor.h  | 
 | 4 | +# @objcwrapper MPSGraphTensor <: MPSGraphObject  | 
 | 5 | + | 
 | 6 | +# Define MPSGraphOperation here to define the MPSGraphTensor properties  | 
 | 7 | +# @objcwrapper MPSGraphOperation <: MPSGraphObject  | 
 | 8 | + | 
 | 9 | +function Base.size(td::MPSGraphTensor)  | 
 | 10 | +    temp = map(td.shape) do nsnum  | 
 | 11 | +        NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int  | 
 | 12 | +    end  | 
 | 13 | +    Tuple(temp)  | 
 | 14 | +end  | 
 | 15 | + | 
 | 16 | +function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...)  | 
 | 17 | +    mpsshape = convert(MPSShape, shape)  | 
 | 18 | +    return placeholderTensor(graph, mpsshape, args...)  | 
 | 19 | +end  | 
 | 20 | +function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor")  | 
 | 21 | +    obj = @objc [graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}  | 
 | 22 | +                                dataType:dataType::MPSDataType  | 
 | 23 | +                                name:name::id{NSString}]::id{MPSGraphTensor}  | 
 | 24 | +    return MPSGraphTensor(obj)  | 
 | 25 | +end  | 
 | 26 | + | 
 | 27 | +## MPSGraphTensorData.h  | 
 | 28 | +# @objcwrapper immutable=false MPSGraphTensorData <: MPSGraphObject  | 
 | 29 | + | 
 | 30 | +function Base.size(td::MPSGraphTensorData)  | 
 | 31 | +    temp = map(td.shape) do nsnum  | 
 | 32 | +        NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int  | 
 | 33 | +    end  | 
 | 34 | +    Tuple(temp)  | 
 | 35 | +end  | 
 | 36 | + | 
 | 37 | +function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType)  | 
 | 38 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 39 | +    tensor = MPSGraphTensorData(obj)  | 
 | 40 | +    finalizer(release, tensor)  | 
 | 41 | +    @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}  | 
 | 42 | +                                    shape:shape::id{MPSShape}  | 
 | 43 | +                                    dataType:dataType::MPSDataType]::id{MPSGraphTensorData}  | 
 | 44 | +    return tensor  | 
 | 45 | +end  | 
 | 46 | +function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowBytes)  | 
 | 47 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 48 | +    tensor = MPSGraphTensorData(obj)  | 
 | 49 | +    finalizer(release, tensor)  | 
 | 50 | +    @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}  | 
 | 51 | +                                    shape:shape::id{MPSShape}  | 
 | 52 | +                                    dataType:dataType::MPSDataType  | 
 | 53 | +                                    rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData}  | 
 | 54 | +    return tensor  | 
 | 55 | +end  | 
 | 56 | +# MPSGraphTensorData(matrix::MtlMatrix{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)  | 
 | 57 | +MPSGraphTensorData(matrix::MtlMatrix) = MPSGraphTensorData(MPSMatrix(matrix))  | 
 | 58 | +MPSGraphTensorData(arr::MtlArray{<:Any, 3}) = MPSGraphTensorData(MPSMatrix(arr))  | 
 | 59 | + | 
 | 60 | +function MPSGraphTensorData(matrix::MPSMatrix)  | 
 | 61 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 62 | +    tensor = MPSGraphTensorData(obj)  | 
 | 63 | +    finalizer(release, tensor)  | 
 | 64 | +    @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}]::id{MPSGraphTensorData}  | 
 | 65 | +    return tensor  | 
 | 66 | +end  | 
 | 67 | + | 
 | 68 | +# rank must be between 1 and 16 inclusive  | 
 | 69 | +function MPSGraphTensorData(matrix::MPSMatrix, rank)  | 
 | 70 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 71 | +    tensor = MPSGraphTensorData(obj)  | 
 | 72 | +    finalizer(release, tensor)  | 
 | 73 | +    @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}  | 
 | 74 | +                              rank:rank::NSUInteger]::id{MPSGraphTensorData}  | 
 | 75 | +    return tensor  | 
 | 76 | +end  | 
 | 77 | + | 
 | 78 | +function MPSGraphTensorData(vector::MPSVector)  | 
 | 79 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 80 | +    tensor = MPSGraphTensorData(obj)  | 
 | 81 | +    finalizer(release, tensor)  | 
 | 82 | +    @objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData}  | 
 | 83 | +    return tensor  | 
 | 84 | +end  | 
 | 85 | +MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)  | 
 | 86 | +# MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))  | 
 | 87 | + | 
 | 88 | +# rank must be between 1 and 16 inclusive  | 
 | 89 | +function MPSGraphTensorData(vector::MPSVector, rank)  | 
 | 90 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 91 | +    tensor = MPSGraphTensorData(obj)  | 
 | 92 | +    finalizer(release, tensor)  | 
 | 93 | +    @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}  | 
 | 94 | +                                          rank:rank::NSUInteger]::id{MPSGraphTensorData}  | 
 | 95 | +    return tensor  | 
 | 96 | +end  | 
 | 97 | + | 
 | 98 | +function MPSGraphTensorData(ndarr::MPSNDArray)  | 
 | 99 | +    obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 100 | +    tensor = MPSGraphTensorData(obj)  | 
 | 101 | +    finalizer(release, tensor)  | 
 | 102 | +    @objc [tensor::id{MPSGraphTensorData} initWithMPSNDArray:ndarr::id{MPSNDArray}]::id{MPSGraphTensorData}  | 
 | 103 | +    return tensor  | 
 | 104 | +end  | 
 | 105 | +# TODO: MPSImage is not yet implemented  | 
 | 106 | +# function MPSGraphTensorData(imgbatch::MPSImageBatch)  | 
 | 107 | +#     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}  | 
 | 108 | +#     tensor = MPSGraphTensorData(obj)  | 
 | 109 | +#     finalizer(release, tensor)  | 
 | 110 | +#     @objc [tensor::id{MPSGraphTensorData} initWithMPSImageBatch:imgbatch::id{MPSImageBatch}]::id{MPSGraphTensorData}  | 
 | 111 | +#     MPSGraphTensorData(obj)  | 
 | 112 | +# end  | 
 | 113 | + | 
 | 114 | +"""  | 
 | 115 | +    MPSNDArray(tens::MPSGraphTensorData)  | 
 | 116 | +
  | 
 | 117 | +Return an MPSNDArray object.  | 
 | 118 | +
  | 
 | 119 | +Will copy contents if the contents are not stored in an MPS ndarray.  | 
 | 120 | +"""  | 
 | 121 | +function MPS.MPSNDArray(tens::MPSGraphTensorData)  | 
 | 122 | +    arr = @objc [tens::id{MPSNDArray} mpsndarray]::id{MPSNDArray}  | 
 | 123 | +    MPSNDArray(arr)  | 
 | 124 | +end  | 
0 commit comments